XLA 自訂呼叫

本文件說明如何使用 XLA FFI 程式庫編寫及使用 XLA 自訂呼叫。自訂呼叫是一種機制,可在編譯時向 XLA 編譯器 (編譯時) 說明 HLO 模組中的外部「作業」,而 XLA FFI 是一種機制,可在執行階段向 XLA 註冊這類運算。FFI 代表「外部函式介面」,是一組 C API,用於定義 XLA 的二進位介面 (ABI),以便呼叫以其他程式設計語言編寫的外部程式碼。XLA 為以 C++ 編寫的 XLA FFI 提供僅限標頭的繫結,進而隱藏使用者的所有基礎 C API 低階詳細資料。

在 CPU 上建立自訂呼叫

您可以建立 HLO 指令,代表透過 XLA 的用戶端 API 進行自訂呼叫。例如,以下程式碼使用自訂呼叫在 CPU 上計算 A[i] = B[i % 128]+ C[i]。(當然,你可以,也應該!請使用一般的 HLO 操作)。

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

在 GPU 上建立自訂呼叫

XLA FFI 的 GPU 自訂呼叫註冊幾乎完全相同,唯一的差別在於,就 GPU 而言,您需要要求基礎平台串流 (CUDA 或 ROCM 串流) 才能在裝置上啟動核心。以下 CUDA 範例的運算 (A[i] = B[i % 128] + C[i]) 與上述 CPU 程式碼相同。

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

請注意,GPU 自訂呼叫函式仍是在 CPU 上執行的函式do_custom_call CPU 函式負責將 GPU 上的工作排入佇列。這裡會啟動 CUDA 核心,但也能執行其他操作,例如呼叫 cuBLAS。

引數和結果也會存在於主機上,且資料成員包含裝置 (例如 GPU) 記憶體的指標。傳遞至自訂呼叫處理常式的緩衝區具有基礎裝置緩衝區的形狀,因此自訂呼叫可以從這些緩衝區計算核心啟動參數。

傳送元組到自訂呼叫

請考慮下列自訂呼叫。

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

在 CPU 和 GPU 上,記憶體中的元組會以指標陣列的形式表示。當 XLA 使用元組引數呼叫自訂呼叫,或結果會合併時,這些呼叫會做為一般緩衝區引數或結果傳遞。

以臨時緩衝區的形式提供元組輸出內容

自訂呼叫的元組輸入雖然方便,但並非強制要求。如果我們不支援自訂呼叫的元組輸入值,您可以先使用 get-tuple-element 解開元組,再將其傳送至自訂呼叫。

另一方面,元組的「輸出」可讓您執行其他原先無法執行的作業。

使用元組輸出的明顯原因是,元組輸出是自訂呼叫 (或任何其他 XLA 運算) 傳回多個獨立陣列的方式。

但較不明顯的是,元組輸出也是提供自訂呼叫臨時記憶體的方法。可以,輸出內容代表臨時緩衝區。試想,輸出緩衝區具有運算能寫入的屬性,並且可以在寫入緩衝區後讀取。這就是您需要的暫存緩衝區。

在上述範例中,假設我們想要使用 F32[1024] 做為臨時緩衝區。接著,按照上述方式編寫 HLO,處理自訂呼叫輸出內容的元組索引 1 就絕不會被讀取。