Пользовательские звонки XLA

В этом документе описывается, как писать и использовать пользовательские вызовы XLA с использованием библиотеки XLA FFI. Пользовательский вызов — это механизм описания внешней «операции» в модуле HLO для компилятора XLA (во время компиляции), а XLA FFI — это механизм регистрации реализации таких операций с помощью XLA (во время выполнения). FFI означает «интерфейс внешней функции» и представляет собой набор API-интерфейсов C, которые определяют двоичный интерфейс (ABI) для вызова XLA во внешний код, написанный на других языках программирования. XLA предоставляет привязки только для заголовков для XLA FFI, написанного на C++, что скрывает от конечного пользователя все низкоуровневые детали базовых API C.

Создать собственный вызов на ЦП

Вы можете создать инструкцию HLO, которая представляет собой пользовательский вызов через клиентский API XLA. Например, следующий код использует специальный вызов для вычисления A[i] = B[i % 128]+ C[i] на ЦП. (Конечно, вы можете – и должны! – сделать это с помощью обычного HLO.)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

Создайте собственный вызов на графическом процессоре

Пользовательская регистрация вызовов графического процессора с помощью XLA FFI практически идентична, с той лишь разницей, что для графического процессора вам необходимо запросить поток базовой платформы (поток CUDA или ROCM), чтобы иметь возможность запускать ядро ​​на устройстве. Вот пример CUDA, который выполняет те же вычисления ( A[i] = B[i % 128] + C[i] ), что и приведенный выше код ЦП.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

Прежде всего обратите внимание, что пользовательская функция вызова графического процессора по-прежнему является функцией, выполняемой на ЦП . Функция CPU do_custom_call отвечает за постановку в очередь работы на графическом процессоре. Здесь он запускает ядро ​​CUDA, но может делать и что-то еще, например вызывать cuBLAS.

Аргументы и результаты также находятся на хосте, а элемент данных содержит указатель на память устройства (т. е. графического процессора). Буферы, передаваемые пользовательскому обработчику вызовов, имеют форму базовых буферов устройства, поэтому пользовательский вызов может вычислять на их основе параметры запуска ядра.

Передача кортежей в пользовательские вызовы

Рассмотрим следующий пользовательский вызов.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

Как на процессоре, так и на графическом процессоре кортеж представлен в памяти как массив указателей. Когда XLA вызывает пользовательские вызовы с аргументами или результатами кортежа, он выравнивает их и передает как обычные аргументы или результаты буфера.

Кортеж выводится как временные буферы

Ввод кортежей в пользовательские вызовы удобен, но не является строго необходимым. Если бы мы не поддерживали входные данные кортежей для пользовательских вызовов, вы всегда могли бы распаковать кортежи с помощью get-tuple-element перед передачей их пользовательскому вызову.

С другой стороны, выходные данные кортежа позволяют вам делать то, что иначе вы бы не смогли.

Очевидная причина использования выходных данных кортежа заключается в том, что выходные данные кортежа — это то, как пользовательский вызов (или любая другая операция XLA) возвращает несколько независимых массивов.

Но менее очевидно, что вывод кортежа также является способом выделения временной памяти для вашего пользовательского вызова. Да, выходные данные могут представлять собой временный буфер. Учтите, что выходной буфер имеет свойство, позволяющее оператору писать в него и читать из него после того, как он был записан. Это именно то, что вы хотите от временного буфера.

Предположим, в приведенном выше примере мы хотим использовать F32[1024] в качестве временного буфера. Тогда мы бы написали HLO так же, как указано выше, и просто никогда не читали бы кортеж с индексом 1 вывода пользовательского вызова.