本文說明如何使用 XLA FFI 程式庫編寫及使用 XLA 自訂呼叫。自訂呼叫是一種機制,可在 HLO 模組中向 XLA 編譯器 (在編譯時) 說明外部「作業」,而 XLA FFI 則是一種機制,可向 XLA 註冊這類作業的實作 (在執行階段)。FFI 是「外部函式介面」的縮寫,是一組 C API,可為 XLA 定義二進位介面 (ABI),以便呼叫以其他程式設計語言編寫的外部程式碼。XLA 提供以 C++ 編寫的 XLA FFI 專屬繫結,可向使用者隱藏基礎 C API 的所有低層級詳細資料。
JAX + XLA 自訂呼叫
如需將自訂呼叫和 XLA FFI 與 JAX 整合的端對端範例,請參閱 JAX 說明文件。
XLA FFI 繫結
XLA FFI 繫結是自訂呼叫簽章的編譯時間規格:自訂呼叫引數、屬性和其類型,以及透過執行環境傳遞的其他參數 (即 GPU 後端的 GPU 串流)。XLA FFI 繫結可繫結至任何 C++ 可呼叫項目 (函式指標、lambda 等),並使用相容的 operator() 簽章。建構的處理常式會解碼 XLA FFI 呼叫框架 (由穩定的 C API 定義)、檢查所有參數的類型,並將解碼結果轉送至使用者定義的回呼。
XLA FFI 繫結大量採用範本中繼程式設計,以便將建構的處理常式編譯為最有效率的機器碼。每個自訂呼叫參數的執行時間負擔,大約是幾奈秒。
以範本特化形式實作的 XLA FFI 自訂點,使用者可以定義如何解碼自訂型別,也就是說,可以為使用者定義的 enum class 型別定義自訂解碼。
從自訂呼叫傳回錯誤
自訂呼叫實作項目必須傳回 xla::ffi::Error 值,向 XLA 執行階段發出成功或錯誤信號。這與 absl::Status 類似,且具有相同的錯誤代碼集。我們不會使用 absl::Status,因為它沒有穩定的 ABI,而且在動態載入的自訂呼叫程式庫和 XLA 本身之間傳遞時並不安全。
// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
[]() { return Error(ErrorCode::kInternal, "Oops!"); });
// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
[]() { return Error::Success(); });
緩衝區引數和結果
XLA 會使用目的地傳遞樣式來處理結果:自訂呼叫 (或任何其他 XLA 作業) 不會為結果分配記憶體,而是寫入 XLA 執行階段傳遞的目的地。XLA 會使用靜態緩衝區指派作業,並在編譯時根據所有值的即時範圍分配緩衝區。
傳遞至 FFI 處理常式的結果會包裝在 Result<T> 範本中,該範本具有類似指標的語意:operator-> 可存取基礎參數。
AnyBuffer 引數和結果可存取任何資料類型的自訂呼叫緩衝區參數。如果自訂呼叫具有適用於多種資料類型的泛型實作,且自訂呼叫實作會根據資料類型執行執行階段分派作業,這就很有用。AnyBuffer 可存取緩衝區資料型別、維度和緩衝區指標。
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any number of dimensions and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
[](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
void* arg_data = arg.untyped_data();
void* res_data = res->untyped_data();
return Error::Success();
});
受限緩衝區引數和結果
Buffer 可用於對緩衝區資料型別和維度數量新增限制,如果執行階段引數與 FFI 處理常式簽章不符,處理常式會自動檢查並向 XLA 執行階段傳回錯誤。
// Buffers of any number of dimensions and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
[](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
float* arg_data = arg.typed_data();
float* res_data = res->typed_data();
return Error::Success();
});
// Buffers of number of dimensions 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
[](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
float* arg_data = arg.typed_data();
float* res_data = res->typed_data();
return Error::Success();
});
可變引數和結果
如果自訂呼叫的不同例項中,引數和結果的數量可能不同,則可在執行階段使用 RemainingArgs 和 RemainingRets 解碼。
auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
[](RemainingArgs args, RemainingRets results) -> Error {
ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);
if (!arg.has_value()) {
return Error(ErrorCode::kInternal, arg.error());
}
if (!res.has_value()) {
return Error(ErrorCode::kInternal, res.error());
}
return Error::Success();
});
可在一般引數和結果後宣告可變引數和結果,但可變引數後繫結一般引數和結果是違法的。
auto handler =
Ffi::Bind()
.Arg<AnyBuffer>()
.RemainingArgs()
.Ret<AnyBuffer>()
.RemainingRets()
.To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
RemainingRets results) -> Error { return Error::Success(); });
屬性
XLA FFI 支援自動解碼以 custom_call backend_config 形式傳遞的 mlir::DictionaryAttr,並將其做為 FFI 處理常式引數。
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
i32 = 42 : i32,
str = "string"
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
在這個範例中,自訂呼叫有一個緩衝區引數和兩個屬性,而 XLA FFI 可以自動解碼並傳遞至使用者定義的可呼叫函式。
auto handler = Ffi::Bind()
.Arg<BufferR0<F32>>()
.Attr<int32_t>("i32")
.Attr<std::string_view>("str")
.To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
return Error::Success();
});
使用者定義的列舉屬性
XLA FFI 可以自動將整數 MLIR 屬性解碼為使用者定義的列舉。列舉類別必須具有相同的基礎整數型別,且解碼必須向 XLA FFI 明確註冊。
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
command = 0 : i32
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
kAdd = 0,
kMul = 1,
};
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);
auto handler = Ffi::Bind().Attr<Command>("command").To(
[](Command command) -> Error { return Error::Success(); });
繫結所有自訂通話屬性
您可以將所有自訂通話屬性當做字典存取,並在執行階段延遲解碼所需的屬性。
auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
return Error::Success();
});
使用者定義的結構體屬性
XLA FFI 可將字典屬性解碼為使用者定義的結構體。
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
range = { lo = 0 : i64, hi = 42 : i64 }
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
在上述範例中,range 是 mlir::DictionaryAttr 屬性,且可自動解碼為 C++ 結構體,不必依名稱存取字典欄位。解碼必須使用 XLA_FFI_REGISTER_STRUCT_ATTR_DECODING 巨集明確註冊 (幕後會在 ::xla::ffi 命名空間中定義範本特化,因此巨集必須新增至全域命名空間)。
struct Range {
int64_t lo;
int64_t hi;
};
XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
StructMember<int64_t>("hi"));
auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
return Error::Success();
});
自訂屬性可以從字典載入,就像任何其他屬性一樣。在以下範例中,所有自訂通話屬性都會解碼為 Dictionary,且可依名稱存取 range。
auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
ErrorOr<Range> range = attrs.get<Range>("range");
return Error::Success();
});
在 CPU 上建立自訂呼叫
您可以透過 XLA 的用戶端 API 建立代表自訂呼叫的 HLO 指令。舉例來說,下列程式碼使用自訂呼叫,在 CPU 上計算 A[i] = B[i %
128]+ C[i]。(當然可以,而且應該這麼做!- 執行此操作
使用一般 HLO。)
#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"
void do_it() {
xla::XlaBuilder b("do_it");
xla::XlaOp param0 =
xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
xla::XlaOp param1 =
xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
xla::XlaOp custom_call =
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
/*opaque=*/"", /*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
/*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
/*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}
// Constrain custom call arguments to 1-dimensional buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;
// Implement a custom call as a C++ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
// Check that dimensions are compatible.
assert(out->dimensions[0] == d1 && "unexpected dimensions");
for (size_t i = 0; i < d1; ++i) {
out->data[i] = in0.data[i % d0] + in1.data[i];
}
}
// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Arg<Buffer>()
.Arg<Buffer>()
.Ret<Buffer>());
// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"Host", handler);
在 GPU 上建立自訂呼叫
使用 XLA FFI 註冊 GPU 自訂呼叫幾乎相同,唯一不同的是,您需要要求基礎平台串流 (CUDA 或 ROCM 串流),才能在裝置上啟動核心。以下是 CUDA 範例,可執行與上述 CPU 程式碼相同的運算 (A[i] = B[i % 128] + C[i])。
void do_it() { /* same implementation as above */ }
__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
out[idx] = in0[idx % 128] + in1[idx];
}
void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
size_t d2 = out->dimensions[0];
assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");
const int64_t block_dim = 64;
const int64_t grid_dim = 2048 / block_dim;
custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
in0.data, in1.data, out->data);
}
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Ctx<xla::ffi::PlatformStream<CUstream>>()
.Arg<BufferF32>()
.Arg<BufferF32>()
.Ret<BufferF32>());
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"CUDA", handler);
請注意,GPU 自訂呼叫函式仍是在 CPU 上執行的函式。do_custom_call CPU 函式負責在 GPU 上將工作加入佇列。這裡會啟動 CUDA 核心,但也可以執行其他動作,
例如呼叫 cuBLAS。
引數和結果也位於主機上,資料成員則包含指向裝置 (即 GPU) 記憶體的指標。傳遞至自訂呼叫處理常式的緩衝區具有基礎裝置緩衝區的形狀,因此自訂呼叫可以從中計算核心啟動參數。
將元組傳遞至自訂呼叫
請參考下列自訂呼叫。
using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {32}),
ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {64}),
ShapeUtil::MakeShape(F32, {128}),
}),
ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");
Shape out_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {512}),
ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);
在 CPU 和 GPU 上,元組在記憶體中會以指標陣列表示。當 XLA 使用元組引數或結果呼叫自訂呼叫時,會將其扁平化,並以一般緩衝區引數或結果的形式傳遞。
以暫時緩衝區形式輸出元組
自訂呼叫的元組輸入內容很方便,但並非必要。如果我們不支援自訂呼叫的元組輸入,您一律可以使用 get-tuple-element 解壓縮元組,再將其傳遞至自訂呼叫。
另一方面,元組 outputs 可讓您執行原本無法執行的操作。
使用元組輸出的顯而易見原因,是自訂呼叫 (或任何其他 XLA 作業) 會透過元組輸出傳回多個獨立陣列。
但較不明顯的是,元組輸出也是為自訂呼叫暫時記憶體提供記憶體的方式。可以,輸出可以代表暫時緩衝區。舉例來說,輸出緩衝區具有作業可寫入的屬性,且在寫入後可從中讀取。這正是您對暫時緩衝區的期望。
在上述範例中,假設我們想將 F32[1024] 用做暫時緩衝區。
接著,我們會像上述一樣編寫 HLO,但不會讀取自訂呼叫輸出內容的元組索引 1。