Chamadas personalizadas do XLA

Este documento descreve como escrever e usar chamadas personalizadas da XLA usando a biblioteca XLA FFI. A chamada personalizada é um mecanismo para descrever uma "operação" externa no módulo HLO para o compilador XLA (no momento da compilação), e a FFI do XLA é um mecanismo para registrar a implementação dessas operações com o XLA (no momento da execução). FFI significa "interface de função externa" e é um conjunto de APIs C que definem uma interface binária (ABI) para o XLA chamar código externo escrito em outras linguagens de programação. O XLA fornece vinculações somente de cabeçalho para FFI do XLA gravadas em C++, que ocultam todos os detalhes de baixo nível das APIs C subjacentes do usuário final.

Chamadas personalizadas do JAX + XLA

Consulte a documentação do JAX para exemplos completos de integração de chamadas personalizadas e XLA FFI com JAX.

Vinculação de FFI do XLA

A vinculação FFI do XLA é uma especificação de tempo de compilação da assinatura de chamada personalizada: argumentos, atributos e tipos de chamada personalizada, além de parâmetros adicionais transmitidos pelo contexto de execução (por exemplo, fluxo de GPU para back-end de GPU). A vinculação da FFI do XLA pode ser vinculada a qualquer elemento chamável em C++ (ponteiro de função, lambda etc.) com uma assinatura operator() compatível. O manipulador construído decodifica o frame de chamada XLA FFI (definido pela API C estável), verifica o tipo de todos os parâmetros e encaminha os resultados decodificados para o callback definido pelo usuário.

A vinculação FFI do XLA depende muito da metaprogramação de modelo para compilar o manipulador construído no código de máquina mais eficiente. Os sobrecargas de tempo de execução são da ordem de alguns nanossegundos para cada parâmetro de chamada personalizada.

Pontos de personalização da FFI do XLA implementados como especializações de modelo, e os usuários podem definir como decodificar seus tipos personalizados, ou seja, é possível definir a decodificação personalizada para tipos enum class definidos pelo usuário.

Retornar erros de chamadas personalizadas

As implementações de chamadas personalizadas precisam retornar o valor xla::ffi::Error para sinalizar sucesso ou erro para o tempo de execução do XLA. É semelhante a absl::Status e tem o mesmo conjunto de códigos de erro. Não usamos absl::Status porque ele não tem uma ABI estável e seria inseguro transmiti-lo entre a biblioteca de chamadas personalizadas carregada dinamicamente e o próprio XLA.

// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
    []() { return Error(ErrorCode::kInternal, "Oops!"); });

// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
    []() { return Error::Success(); });

Argumentos e resultados de buffer

O XLA usa o estilo de transmissão de destino para resultados: chamadas personalizadas (ou qualquer outra operação do XLA) não alocam memória para resultados. Em vez disso, elas gravam em destinos transmitidos pelo tempo de execução do XLA. O XLA usa atribuição estática de buffer e aloca buffers para todos os valores com base nos intervalos ativos deles no momento da compilação.

Os resultados transmitidos para gerenciadores de FFI são encapsulados em um modelo Result<T>, que tem uma semântica semelhante a um ponteiro: operator-> dá acesso ao parâmetro subjacente.

Os argumentos e resultados de AnyBuffer dão acesso a parâmetros de buffer de chamada personalizados de qualquer tipo de dados. Isso é útil quando a chamada personalizada tem uma implementação genérica que funciona para vários tipos de dados, e a implementação da chamada personalizada faz o envio em tempo de execução com base no tipo de dados. AnyBuffer dá acesso ao tipo de dados, às dimensões e a um ponteiro para o buffer.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any number of dimensions and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
    [](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
      void* arg_data = arg.untyped_data();
      void* res_data = res->untyped_data();
      return Error::Success();
    });

Argumentos e resultados de buffer restritos

Buffer permite adicionar restrições ao tipo de dados do buffer e ao número de dimensões. Elas serão verificadas automaticamente pelo manipulador e retornarão um erro ao tempo de execução do XLA se os argumentos de tempo de execução não corresponderem à assinatura do manipulador de FFI.

// Buffers of any number of dimensions and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
    [](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });
// Buffers of number of dimensions 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
    [](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });

Argumentos e resultados variádicos

Se o número de argumentos e o resultado puderem ser diferentes em várias instâncias de uma chamada personalizada, eles poderão ser decodificados no tempo de execução usando RemainingArgs e RemainingRets.

auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
    [](RemainingArgs args, RemainingRets results) -> Error {
      ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
      ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);

      if (!arg.has_value()) {
        return Error(ErrorCode::kInternal, arg.error());
      }

      if (!res.has_value()) {
        return Error(ErrorCode::kInternal, res.error());
      }

      return Error::Success();
    });

Argumentos e resultados variádicos podem ser declarados depois de argumentos e resultados regulares. No entanto, vincular argumentos e resultados regulares depois de um variádico é ilegal.

auto handler =
    Ffi::Bind()
        .Arg<AnyBuffer>()
        .RemainingArgs()
        .Ret<AnyBuffer>()
        .RemainingRets()
        .To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
               RemainingRets results) -> Error { return Error::Success(); });

Atributos

A FFI do XLA oferece suporte à decodificação automática de mlir::DictionaryAttr transmitidos como um custom_call backend_config em argumentos de manipulador de FFI.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    i32 = 42 : i32,
    str = "string"
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

Neste exemplo, a chamada personalizada tem um único argumento de buffer e dois atributos, e a FFI XLA pode decodificá-los automaticamente e transmitir para o elemento invocável definido pelo usuário.

auto handler = Ffi::Bind()
  .Arg<BufferR0<F32>>()
  .Attr<int32_t>("i32")
  .Attr<std::string_view>("str")
  .To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
    return Error::Success();
  });

Atributos de enumeração definidos pelo usuário

A FFI do XLA pode decodificar automaticamente atributos integrais do MLIR em enums definidos pelo usuário. A classe de enumeração precisa ter o mesmo tipo integral subjacente, e a decodificação precisa ser registrada explicitamente com a FFI do XLA.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    command = 0 : i32
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
  kAdd = 0,
  kMul = 1,
};

XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);

auto handler = Ffi::Bind().Attr<Command>("command").To(
    [](Command command) -> Error { return Error::Success(); });

Vinculando todos os atributos personalizados de chamada

É possível acessar todos os atributos de chamada personalizados como um dicionário e decodificar de forma lenta apenas os atributos necessários no tempo de execução.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
  return Error::Success();
});

Atributos de struct definidos pelo usuário

A FFI do XLA pode decodificar atributos de dicionário em structs definidos pelo usuário.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    range = { lo = 0 : i64, hi = 42 : i64 }
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

No exemplo acima, range é um atributo mlir::DictionaryAttr e, em vez de acessar campos de dicionário por nome, ele pode ser decodificado automaticamente como uma struct C++. A decodificação precisa ser registrada explicitamente com uma macro XLA_FFI_REGISTER_STRUCT_ATTR_DECODING. Nos bastidores, ela define uma especialização de modelo no namespace ::xla::ffi. Portanto, a macro precisa ser adicionada ao namespace global.

struct Range {
  int64_t lo;
  int64_t hi;
};

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
                                             StructMember<int64_t>("hi"));

auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
  return Error::Success();
});

Os atributos personalizados podem ser carregados de um dicionário, assim como qualquer outro atributo. No exemplo abaixo, todos os atributos de chamada personalizados são decodificados como um Dictionary, e um range pode ser acessado por nome.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<Range> range = attrs.get<Range>("range");
  return Error::Success();
});

Criar uma chamada personalizada na CPU

É possível criar uma instrução HLO que representa uma chamada personalizada usando a API do cliente do XLA. Por exemplo, o código a seguir usa uma chamada personalizada para calcular A[i] = B[i % 128]+ C[i] na CPU. É claro que você pode e deve! – faça isso com HLO regular.)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to 1-dimensional buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C++ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

Criar uma chamada personalizada na GPU

O registro de chamada personalizada de GPU com XLA FFI é quase idêntico. A única diferença é que, para GPU, você precisa pedir um fluxo de plataforma subjacente (fluxo CUDA ou ROCM) para iniciar o kernel no dispositivo. Confira um exemplo de CUDA que faz o mesmo cálculo (A[i] = B[i % 128] + C[i]) que o código de CPU acima.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

Primeiro, observe que a função de chamada personalizada da GPU ainda é executada na CPU. A função de CPU do_custom_call é responsável por enfileirar o trabalho na GPU. Aqui, ele inicia um kernel CUDA, mas também pode fazer outra coisa, como chamar cuBLAS.

Os argumentos e resultados também ficam no host, e o membro de dados contém um ponteiro para a memória do dispositivo (ou seja, GPU). Os buffers transmitidos ao manipulador de chamadas personalizado têm o formato dos buffers do dispositivo subjacente. Assim, a chamada personalizada pode calcular os parâmetros de inicialização do kernel com base neles.

Como transmitir tuplas para chamadas personalizadas

Considere a seguinte chamada personalizada.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

Na CPU e na GPU, uma tupla é representada na memória como uma matriz de ponteiros. Quando o XLA chama funções personalizadas com argumentos ou resultados de tupla, ele os simplifica e transmite como argumentos ou resultados de buffer regulares.

Tuplas de saída como buffers temporários

As tuplas de entrada para chamadas personalizadas são convenientes, mas não são estritamente necessárias. Se não oferecêssemos suporte a entradas de tupla para chamadas personalizadas, você sempre poderia descompactar as tuplas usando "get-tuple-element" antes de transmiti-las para a chamada personalizada.

Por outro lado, as saídas de tupla permitem fazer coisas que não seriam possíveis de outra forma.

O motivo óbvio para ter saídas de tupla é que elas são a maneira como uma chamada personalizada (ou qualquer outra operação do XLA) retorna várias matrizes independentes.

Mas, de forma menos óbvia, uma saída de tupla também é uma maneira de dar à sua memória temporária de chamada personalizada. Sim, uma saída pode representar um buffer temporário. Considere que um buffer de saída tem a propriedade de que a operação pode gravar nele e ler dele depois que ele é gravado. É exatamente isso que você quer de um buffer temporário.

No exemplo acima, suponha que você queira usar o F32[1024] como um buffer temporário. Em seguida, escreveríamos o HLO como acima e simplesmente nunca leríamos o índice de tupla 1 da saída da chamada personalizada.