XLA 自訂呼叫

本文說明如何使用 XLA FFI 程式庫編寫及使用 XLA 自訂呼叫。自訂呼叫是一種機制,可在 HLO 模組中向 XLA 編譯器 (在編譯時) 說明外部「作業」,而 XLA FFI 則是一種機制,可向 XLA 註冊這類作業的實作 (在執行階段)。FFI 是「外部函式介面」的縮寫,是一組 C API,可為 XLA 定義二進位介面 (ABI),以便呼叫以其他程式設計語言編寫的外部程式碼。XLA 提供以 C++ 編寫的 XLA FFI 專屬繫結,可向使用者隱藏基礎 C API 的所有低層級詳細資料。

JAX + XLA 自訂呼叫

如需將自訂呼叫和 XLA FFI 與 JAX 整合的端對端範例,請參閱 JAX 說明文件

XLA FFI 繫結

XLA FFI 繫結是自訂呼叫簽章的編譯時間規格:自訂呼叫引數、屬性和其類型,以及透過執行環境傳遞的其他參數 (即 GPU 後端的 GPU 串流)。XLA FFI 繫結可繫結至任何 C++ 可呼叫項目 (函式指標、lambda 等),並使用相容的 operator() 簽章。建構的處理常式會解碼 XLA FFI 呼叫框架 (由穩定的 C API 定義)、檢查所有參數的類型,並將解碼結果轉送至使用者定義的回呼。

XLA FFI 繫結大量採用範本中繼程式設計,以便將建構的處理常式編譯為最有效率的機器碼。每個自訂呼叫參數的執行時間負擔,大約是幾奈秒。

以範本特化形式實作的 XLA FFI 自訂點,使用者可以定義如何解碼自訂型別,也就是說,可以為使用者定義的 enum class 型別定義自訂解碼。

從自訂呼叫傳回錯誤

自訂呼叫實作項目必須傳回 xla::ffi::Error 值,向 XLA 執行階段發出成功或錯誤信號。這與 absl::Status 類似,且具有相同的錯誤代碼集。我們不會使用 absl::Status,因為它沒有穩定的 ABI,而且在動態載入的自訂呼叫程式庫和 XLA 本身之間傳遞時並不安全。

// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
    []() { return Error(ErrorCode::kInternal, "Oops!"); });

// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
    []() { return Error::Success(); });

緩衝區引數和結果

XLA 會使用目的地傳遞樣式來處理結果:自訂呼叫 (或任何其他 XLA 作業) 不會為結果分配記憶體,而是寫入 XLA 執行階段傳遞的目的地。XLA 會使用靜態緩衝區指派作業,並在編譯時根據所有值的即時範圍分配緩衝區。

傳遞至 FFI 處理常式的結果會包裝在 Result<T> 範本中,該範本具有類似指標的語意:operator-> 可存取基礎參數。

AnyBuffer 引數和結果可存取任何資料類型的自訂呼叫緩衝區參數。如果自訂呼叫具有適用於多種資料類型的泛型實作,且自訂呼叫實作會根據資料類型執行執行階段分派作業,這就很有用。AnyBuffer 可存取緩衝區資料型別、維度和緩衝區指標。

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any number of dimensions and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
    [](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
      void* arg_data = arg.untyped_data();
      void* res_data = res->untyped_data();
      return Error::Success();
    });

受限緩衝區引數和結果

Buffer 可用於對緩衝區資料型別和維度數量新增限制,如果執行階段引數與 FFI 處理常式簽章不符,處理常式會自動檢查並向 XLA 執行階段傳回錯誤。

// Buffers of any number of dimensions and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
    [](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });
// Buffers of number of dimensions 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
    [](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });

可變引數和結果

如果自訂呼叫的不同例項中,引數和結果的數量可能不同,則可在執行階段使用 RemainingArgsRemainingRets 解碼。

auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
    [](RemainingArgs args, RemainingRets results) -> Error {
      ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
      ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);

      if (!arg.has_value()) {
        return Error(ErrorCode::kInternal, arg.error());
      }

      if (!res.has_value()) {
        return Error(ErrorCode::kInternal, res.error());
      }

      return Error::Success();
    });

可在一般引數和結果後宣告可變引數和結果,但可變引數後繫結一般引數和結果是違法的。

auto handler =
    Ffi::Bind()
        .Arg<AnyBuffer>()
        .RemainingArgs()
        .Ret<AnyBuffer>()
        .RemainingRets()
        .To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
               RemainingRets results) -> Error { return Error::Success(); });

屬性

XLA FFI 支援自動解碼以 custom_call backend_config 形式傳遞的 mlir::DictionaryAttr,並將其做為 FFI 處理常式引數。

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    i32 = 42 : i32,
    str = "string"
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

在這個範例中,自訂呼叫有一個緩衝區引數和兩個屬性,而 XLA FFI 可以自動解碼並傳遞至使用者定義的可呼叫函式。

auto handler = Ffi::Bind()
  .Arg<BufferR0<F32>>()
  .Attr<int32_t>("i32")
  .Attr<std::string_view>("str")
  .To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
    return Error::Success();
  });

使用者定義的列舉屬性

XLA FFI 可以自動將整數 MLIR 屬性解碼為使用者定義的列舉。列舉類別必須具有相同的基礎整數型別,且解碼必須向 XLA FFI 明確註冊。

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    command = 0 : i32
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
  kAdd = 0,
  kMul = 1,
};

XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);

auto handler = Ffi::Bind().Attr<Command>("command").To(
    [](Command command) -> Error { return Error::Success(); });

繫結所有自訂通話屬性

您可以將所有自訂通話屬性當做字典存取,並在執行階段延遲解碼所需的屬性。

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
  return Error::Success();
});

使用者定義的結構體屬性

XLA FFI 可將字典屬性解碼為使用者定義的結構體。

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    range = { lo = 0 : i64, hi = 42 : i64 }
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

在上述範例中,rangemlir::DictionaryAttr 屬性,且可自動解碼為 C++ 結構體,不必依名稱存取字典欄位。解碼必須使用 XLA_FFI_REGISTER_STRUCT_ATTR_DECODING 巨集明確註冊 (幕後會在 ::xla::ffi 命名空間中定義範本特化,因此巨集必須新增至全域命名空間)。

struct Range {
  int64_t lo;
  int64_t hi;
};

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
                                             StructMember<int64_t>("hi"));

auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
  return Error::Success();
});

自訂屬性可以從字典載入,就像任何其他屬性一樣。在以下範例中,所有自訂通話屬性都會解碼為 Dictionary,且可依名稱存取 range

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<Range> range = attrs.get<Range>("range");
  return Error::Success();
});

在 CPU 上建立自訂呼叫

您可以透過 XLA 的用戶端 API 建立代表自訂呼叫的 HLO 指令。舉例來說,下列程式碼使用自訂呼叫,在 CPU 上計算 A[i] = B[i % 128]+ C[i]。(當然可以,而且應該這麼做!- 執行此操作 使用一般 HLO。)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to 1-dimensional buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C++ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

在 GPU 上建立自訂呼叫

使用 XLA FFI 註冊 GPU 自訂呼叫幾乎相同,唯一不同的是,您需要要求基礎平台串流 (CUDA 或 ROCM 串流),才能在裝置上啟動核心。以下是 CUDA 範例,可執行與上述 CPU 程式碼相同的運算 (A[i] = B[i % 128] + C[i])。

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

請注意,GPU 自訂呼叫函式仍是在 CPU 上執行的函式do_custom_call CPU 函式負責在 GPU 上將工作加入佇列。這裡會啟動 CUDA 核心,但也可以執行其他動作, 例如呼叫 cuBLAS。

引數和結果也位於主機上,資料成員則包含指向裝置 (即 GPU) 記憶體的指標。傳遞至自訂呼叫處理常式的緩衝區具有基礎裝置緩衝區的形狀,因此自訂呼叫可以從中計算核心啟動參數。

將元組傳遞至自訂呼叫

請參考下列自訂呼叫。

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

在 CPU 和 GPU 上,元組在記憶體中會以指標陣列表示。當 XLA 使用元組引數或結果呼叫自訂呼叫時,會將其扁平化,並以一般緩衝區引數或結果的形式傳遞。

以暫時緩衝區形式輸出元組

自訂呼叫的元組輸入內容很方便,但並非必要。如果我們不支援自訂呼叫的元組輸入,您一律可以使用 get-tuple-element 解壓縮元組,再將其傳遞至自訂呼叫。

另一方面,元組 outputs 可讓您執行原本無法執行的操作。

使用元組輸出的顯而易見原因,是自訂呼叫 (或任何其他 XLA 作業) 會透過元組輸出傳回多個獨立陣列。

但較不明顯的是,元組輸出也是為自訂呼叫暫時記憶體提供記憶體的方式。可以,輸出可以代表暫時緩衝區。舉例來說,輸出緩衝區具有作業可寫入的屬性,且在寫入後可從中讀取。這正是您對暫時緩衝區的期望。

在上述範例中,假設我們想將 F32[1024] 用做暫時緩衝區。 接著,我們會像上述一樣編寫 HLO,但不會讀取自訂呼叫輸出內容的元組索引 1。