Chamadas personalizadas do XLA

Este documento descreve como criar e usar chamadas personalizadas do XLA usando o XLA FFI. biblioteca. A chamada personalizada é um mecanismo para descrever uma "operação" externa no módulo HLO para o compilador XLA (no momento da compilação), e a FFI XLA é um mecanismo para registrar a implementação dessas operações com o XLA (no momento da execução). FFI significa "interface de função externa" e é um conjunto de APIs C que definem uma interface binária (ABI) para que a XLA chame o código externo escrito em outras linguagens de programação. O XLA fornece vinculações somente de cabeçalho para a FFI do XLA escrita em C++, o que oculta todos os detalhes de baixo nível das APIs C subjacentes do usuário final.

Chamadas personalizadas JAX + XLA

Consulte a documentação do JAX para exemplos completos de integração de chamadas personalizadas e XLA FFI com JAX.

Vinculação FFI do XLA

A vinculação FFI do XLA é uma especificação no momento da compilação da assinatura de chamada personalizada: argumentos de chamada personalizados, atributos e tipos deles, além de outros parâmetros transmitidos pelo contexto de execução (ou seja, transmissão de GPU para o back-end da GPU). A descoberta de FFI do XLA pode ser vinculada a qualquer elemento invocável em C++ (ponteiro de função, lambda etc.) com assinatura operator() compatível. O gerenciador construído decodifica a chamada FFI do XLA frame (definido pela API C estável), verificar todos os parâmetros de tipo e encaminhar decodificados para o callback definido pelo usuário.

A vinculação FFI do XLA depende muito da metaprogramação de modelo para poder compilar o manipulador construído para o código de máquina mais eficiente. Tempo de execução as sobrecargas de trabalho são de alguns nanossegundos para cada chamada personalizada .

pontos de personalização de XLA FFI implementados como especializações de modelo e os usuários podem definir como decodificar seus tipos personalizados, ou seja, é possível para definir a decodificação personalizada para tipos de enum class definidos pelo usuário.

Como retornar erros de chamadas personalizadas

As implementações de chamada personalizada precisam retornar o valor xla::ffi::Error para sinalizar sucesso ou erro para o ambiente de execução do XLA. Ela é semelhante a absl::Status e tem o mesmo conjunto de códigos de erro. Não usamos absl::Status porque ele usa têm uma ABI estável, e não seria seguro passá-la entre a biblioteca de chamadas personalizada carregada e o próprio XLA.

// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
    []() { return Error(ErrorCode::kInternal, "Oops!"); });

// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
    []() { return Error::Success(); });

Argumentos e resultados do buffer

A XLA usa o estilo de transmissão de destino para resultados: chamadas personalizadas (ou qualquer outra operação da XLA) não alocam memória para resultados. Em vez disso, elas gravam em destinos transmitidos pelo ambiente de execução da XLA. O XLA usa buffer estático e aloca buffers para todos os valores com base nos intervalos ativos em tempo de compilação.

Os resultados são transmitidos para gerenciadores de FFI agrupados em um modelo Result<T>, que tem uma semântica semelhante a ponteiro: operator-> dá acesso ao objeto .

Os argumentos e resultados AnyBuffer dão acesso a parâmetros personalizados do buffer de chamada. de qualquer tipo de dados. Isso é útil quando a chamada personalizada tem uma implementação genérica que funcione para vários tipos de dados, e a implementação de chamadas personalizadas não é de envio com base no tipo de dados. AnyBuffer dá acesso aos dados do buffer. tipo, dimensões e um ponteiro para o próprio buffer.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any rank and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
    [](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
      void* arg_data = arg.untyped_data();
      void* res_data = res->untyped_data();
      return Error::Success();
    });

Argumentos e resultados de buffer limitados

Buffer permite adicionar restrições ao tipo e à classificação de dados do buffer. Elas serão verificadas automaticamente pelo gerenciador e retornarão um erro ao tempo de execução da XLA, se os argumentos de execução não corresponderem à assinatura do gerenciador FFI.

// Buffers of any rank and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
    [](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });
// Buffers of rank 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
    [](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });

Argumentos e resultados variáveis

Se o número de argumentos e resultado puder ser diferente em instâncias diferentes de uma chamada personalizada, eles podem ser decodificados no ambiente de execução usando RemainingArgs e RemainingRets.

auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
    [](RemainingArgs args, RemainingRets results) -> Error {
      ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
      ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);

      if (!arg.has_value()) {
        return Error(ErrorCode::kInternal, arg.error());
      }

      if (!res.has_value()) {
        return Error(ErrorCode::kInternal, res.error());
      }

      return Error::Success();
    });

Argumentos e resultados variados podem ser declarados após argumentos regulares e resultados, no entanto, vincular argumentos regulares e resultados após o variado ilegal.

auto handler =
    Ffi::Bind()
        .Arg<AnyBuffer>()
        .RemainingArgs()
        .Ret<AnyBuffer>()
        .RemainingRets()
        .To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
               RemainingRets results) -> Error { return Error::Success(); });

Atributos

O XLA FFI oferece suporte à decodificação automática de mlir::DictionaryAttr transmitido como um custom_call backend_config em argumentos de gerenciador FFI.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    i32 = 42 : i32,
    str = "string"
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

Neste exemplo, a chamada personalizada tem um único argumento de buffer e dois atributos, e A FFI do XLA pode decodificá-las automaticamente e passá-las para o chamável definido pelo usuário.

auto handler = Ffi::Bind()
  .Arg<BufferR0<F32>>()
  .Attr<int32_t>("i32")
  .Attr<std::string_view>("str")
  .To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
    return Error::Success();
  });

Atributos de enumeração definidos pelo usuário

A XLA FFI pode decodificar automaticamente atributos MLIR integrais em enumerações definidas pelo usuário. A classe de enumeração precisa ter o mesmo tipo integral e a decodificação precisa ser registrada explicitamente com o XLA FFI.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    command = 0 : i32
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
  kAdd = 0,
  kMul = 1,
};

XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);

auto handler = Ffi::Bind().Attr<Command>("command").To(
    [](Command command) -> Error { return Error::Success(); });

Vincular todos os atributos de chamada personalizados

É possível ter acesso a todos os atributos de chamada personalizados como um dicionário e decodificam lentamente apenas os atributos necessários no ambiente de execução.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
  return Error::Success();
});

Atributos de estrutura definidos pelo usuário

A FFI do XLA pode decodificar atributos do dicionário em estruturas definidas pelo usuário.

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    range = { lo = 0 : i64, hi = 42 : i64 }
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

No exemplo acima, range é um atributo mlir::DictionaryAttr. Em vez de acessar campos do dicionário pelo nome, ele pode ser automaticamente decodificado como um struct C++. A decodificação precisa ser registrada explicitamente com uma macro XLA_FFI_REGISTER_STRUCT_ATTR_DECODING. Por trás dos bastidores, ela define uma especialização de modelo no namespace ::xla::ffi. Portanto, a macro precisa ser adicionada ao namespace global.

struct Range {
  int64_t lo;
  int64_t hi;
};

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("i64"),
                                             StructMember<int64_t>("i64"));

auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
  return Error::Success();
});

Os atributos personalizados podem ser carregados de um dicionário, assim como qualquer outro atributo. No exemplo abaixo, todos os atributos de chamada personalizados decodificados como Dictionary e um range podem ser acessados por nome.

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<Range> range = attrs.get<Range>("range");
  return Error::Success();
});

Criar uma chamada personalizada na CPU

É possível criar uma instrução HLO que represente uma chamada personalizada pelo cliente do XLA. API. Por exemplo, o código a seguir usa uma chamada personalizada para calcular A[i] = B[i % 128]+ C[i] na CPU. Claro que você pode e deve! – faça isso com o HLO normal.)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

Criar uma chamada personalizada na GPU

O registro de chamada personalizado da GPU com XLA FFI é quase idêntico, o único A diferença é que, para a GPU, você precisa solicitar um stream de plataforma subjacente (fluxo CUDA ou ROCM) para iniciar o kernel no dispositivo. Aqui está um CUDA exemplo que faz o mesmo cálculo (A[i] = B[i % 128] + C[i]) que a CPU código acima.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

Observe primeiro que a função de chamada personalizada da GPU ainda é uma função executada na CPU. A função de CPU do_custom_call é responsável por enfileirar o trabalho na GPU. Aqui, ele inicia um kernel CUDA, mas também pode fazer outra coisa, como chamar cuBLAS.

Os argumentos e os resultados também ficam no host, e o membro de dados contém um ponteiro para a memória do dispositivo (ou seja, GPU). Os buffers transmitidos para o gerenciador de chamadas personalizado têm a forma dos buffers do dispositivo subjacente. Assim, a chamada personalizada pode calcular os parâmetros de inicialização do kernel a partir deles.

Como transmitir tuplas para chamadas personalizadas

Considere a seguinte chamada personalizada.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

Na CPU e na GPU, uma tupla é representada na memória como uma matriz de ponteiros. Quando o XLA chama chamadas personalizadas com argumentos ou resultados de tupla, ele as nivela e passa como resultados ou argumentos de buffer regular.

Tuplas saídas como buffers de temperatura

As entradas de tupla para chamadas personalizadas são convenientes, mas não são estritamente necessárias. Se não houvesse suporte a entradas de tupla para chamadas personalizadas, você poderia descompactar as tuplas usando o get-tuple-element antes de transmiti-las para a chamada personalizada.

Por outro lado, as saídas de tuplas permitem fazer coisas que não poderiam ser de outra forma.

A razão óbvia para ter saídas de tupla é que elas são como uma chamada personalizada (ou qualquer outra operação XLA) retorna várias matrizes independentes.

Mas, menos obviamente, uma saída de tupla também é uma maneira de fornecer memória temporária à chamada personalizada. Sim, uma saída pode representar um buffer temporário. Considere que um buffer de saída tem a propriedade que a operação pode gravar nele e pode ler dele depois que ele for gravado. É exatamente isso que você quer de um buffer temporário.

No exemplo acima, suponha que queremos usar o F32[1024] como um buffer temporário. A seguir, gravaríamos o HLO como mostrado acima e simplesmente nunca leríamos o índice de tupla 1. da saída da chamada personalizada.