이 문서에서는 XLA FFI 라이브러리를 사용하여 XLA 커스텀 호출을 작성하고 사용하는 방법을 설명합니다. 맞춤 호출은 컴파일 시간에 HLO 모듈의 외부 '작업'을 XLA 컴파일러에 설명하는 메커니즘이고 XLA FFI는 런타임에 이러한 작업의 구현을 XLA에 등록하는 메커니즘입니다. FFI는 'foreign function interface'의 약자로, XLA가 다른 프로그래밍 언어로 작성된 외부 코드를 호출할 수 있도록 바이너리 인터페이스 (ABI)를 정의하는 C API 집합입니다. XLA는 C++로 작성된 XLA FFI의 헤더 전용 바인딩을 제공하여 기본 C API의 모든 하위 수준 세부정보를 최종 사용자로부터 숨깁니다.
JAX + XLA 맞춤 호출
JAX와 커스텀 호출 및 XLA FFI를 통합하는 엔드 투 엔드 예시는 JAX 문서를 참고하세요.
XLA FFI 바인딩
XLA FFI 바인딩은 맞춤 호출 서명의 컴파일 시간 사양입니다. 맞춤 호출 인수, 속성 및 유형, 실행 컨텍스트를 통해 전달되는 추가 매개변수 (예: GPU 백엔드의 GPU 스트림) XLA FFI 바인딩은 호환되는 operator() 서명이 있는 모든 C++ 호출 가능 항목 (함수 포인터, 람다 등)에 바인딩할 수 있습니다. 생성된 핸들러는 안정적인 C API로 정의된 XLA FFI 호출 프레임을 디코딩하고 모든 매개변수의 유형을 확인하며 디코딩된 결과를 사용자 정의 콜백으로 전달합니다.
XLA FFI 바인딩은 템플릿 메타 프로그래밍을 사용하여 생성된 핸들러를 가장 효율적인 머신 코드로 컴파일합니다. 런타임 오버헤드는 각 맞춤 호출 매개변수에 대해 몇 나노초 순입니다.
템플릿 전문화로 구현된 XLA FFI 맞춤설정 포인트로, 사용자는 맞춤 유형을 디코딩하는 방법을 정의할 수 있습니다. 즉, 사용자 정의 enum class 유형의 맞춤 디코딩을 정의할 수 있습니다.
맞춤 호출에서 오류 반환
맞춤 호출 구현은 xla::ffi::Error 값을 반환하여 XLA 런타임에 성공 또는 오류를 알려야 합니다. absl::Status와 유사하며 동일한 오류 코드 집합이 있습니다. absl::Status는 안정적인 ABI가 없으며 동적으로 로드된 맞춤 호출 라이브러리와 XLA 자체 간에 전달하는 것이 안전하지 않으므로 사용하지 않습니다.
// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
[]() { return Error(ErrorCode::kInternal, "Oops!"); });
// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
[]() { return Error::Success(); });
버퍼 인수 및 결과
XLA는 결과에 대상 전달 스타일을 사용합니다. 맞춤 호출 (또는 기타 XLA 작업)은 결과에 메모리를 할당하지 않고 XLA 런타임에서 전달한 대상에 씁니다. XLA는 정적 버퍼 할당을 사용하고 컴파일 시간에 활성 범위를 기반으로 모든 값에 버퍼를 할당합니다.
포인터와 유사한 의미 체계가 있는 Result<T> 템플릿으로 래핑된 FFI 핸들러에 전달된 결과: operator->은 기본 매개변수에 대한 액세스 권한을 제공합니다.
AnyBuffer 인수와 결과를 통해 모든 데이터 유형의 맞춤 호출 버퍼 매개변수에 액세스할 수 있습니다. 이는 맞춤 호출에 여러 데이터 유형에 작동하는 일반 구현이 있고 맞춤 호출 구현이 데이터 유형에 따라 런타임 디스패치를 실행하는 경우에 유용합니다. AnyBuffer은 버퍼 데이터 유형, 차원, 버퍼 자체에 대한 포인터에 대한 액세스 권한을 제공합니다.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any number of dimensions and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
[](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
void* arg_data = arg.untyped_data();
void* res_data = res->untyped_data();
return Error::Success();
});
제약이 있는 버퍼 인수 및 결과
Buffer를 사용하면 버퍼 데이터 유형과 차원 수에 제약 조건을 추가할 수 있으며, 런타임 인수가 FFI 핸들러 서명과 일치하지 않으면 핸들러에서 자동으로 확인하고 XLA 런타임에 오류를 반환합니다.
// Buffers of any number of dimensions and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
[](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
float* arg_data = arg.typed_data();
float* res_data = res->typed_data();
return Error::Success();
});
// Buffers of number of dimensions 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
[](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
float* arg_data = arg.typed_data();
float* res_data = res->typed_data();
return Error::Success();
});
가변 인수 및 결과
인수와 결과의 수가 맞춤 호출의 여러 인스턴스에서 다를 수 있는 경우 RemainingArgs 및 RemainingRets를 사용하여 런타임에 디코딩할 수 있습니다.
auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
[](RemainingArgs args, RemainingRets results) -> Error {
ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);
if (!arg.has_value()) {
return Error(ErrorCode::kInternal, arg.error());
}
if (!res.has_value()) {
return Error(ErrorCode::kInternal, res.error());
}
return Error::Success();
});
가변 인수와 결과는 일반 인수와 결과 뒤에 선언할 수 있지만 가변 인수 뒤에 일반 인수를 바인딩하는 것은 허용되지 않습니다.
auto handler =
Ffi::Bind()
.Arg<AnyBuffer>()
.RemainingArgs()
.Ret<AnyBuffer>()
.RemainingRets()
.To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
RemainingRets results) -> Error { return Error::Success(); });
속성
XLA FFI는 custom_call backend_config로 전달된 mlir::DictionaryAttr를 FFI 핸들러 인수로 자동 디코딩하는 것을 지원합니다.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
i32 = 42 : i32,
str = "string"
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
이 예에서 맞춤 호출에는 단일 버퍼 인수와 두 속성이 있으며 XLA FFI는 이를 자동으로 디코딩하여 사용자 정의 호출 가능 객체에 전달할 수 있습니다.
auto handler = Ffi::Bind()
.Arg<BufferR0<F32>>()
.Attr<int32_t>("i32")
.Attr<std::string_view>("str")
.To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
return Error::Success();
});
사용자 정의 열거형 속성
XLA FFI는 정수 MLIR 속성을 사용자 정의 enum으로 자동 디코딩할 수 있습니다. 열거형 클래스의 기본 정수 유형이 동일해야 하며 디코딩이 XLA FFI에 명시적으로 등록되어야 합니다.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
command = 0 : i32
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
kAdd = 0,
kMul = 1,
};
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);
auto handler = Ffi::Bind().Attr<Command>("command").To(
[](Command command) -> Error { return Error::Success(); });
모든 맞춤 통화 속성 바인딩
모든 맞춤 통화 속성에 딕셔너리로 액세스하고 런타임에 필요한 속성만 지연 디코딩할 수 있습니다.
auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
return Error::Success();
});
사용자 정의 구조체 속성
XLA FFI는 사전 속성을 사용자 정의 구조체로 디코딩할 수 있습니다.
%0 = "stablehlo.custom_call"(%arg0) {
call_target_name = "foo",
backend_config= {
range = { lo = 0 : i64, hi = 42 : i64 }
},
api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
위 예에서 range은 mlir::DictionaryAttr 속성이며, 이름으로 사전 필드에 액세스하는 대신 C++ 구조체로 자동 디코딩할 수 있습니다. 디코딩은 XLA_FFI_REGISTER_STRUCT_ATTR_DECODING 매크로를 사용하여 명시적으로 등록해야 합니다 (내부적으로 ::xla::ffi 네임스페이스에 템플릿 전문화를 정의하므로 매크로를 전역 네임스페이스에 추가해야 함).
struct Range {
int64_t lo;
int64_t hi;
};
XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
StructMember<int64_t>("hi"));
auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
return Error::Success();
});
커스텀 속성은 다른 속성과 마찬가지로 사전에서 로드할 수 있습니다. 아래 예시에서는 모든 맞춤 호출 속성이 Dictionary로 디코딩되며 range는 이름으로 액세스할 수 있습니다.
auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
ErrorOr<Range> range = attrs.get<Range>("range");
return Error::Success();
});
CPU에서 맞춤 호출 만들기
XLA의 클라이언트 API를 통해 맞춤 호출을 나타내는 HLO 명령어를 만들 수 있습니다. 예를 들어 다음 코드는 맞춤 호출을 사용하여 CPU에서 A[i] = B[i %
128]+ C[i]를 계산합니다. (물론 가능하며 그렇게 하는 것이 좋습니다. – 일반 HLO로 이 작업을 실행합니다.)
#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"
void do_it() {
xla::XlaBuilder b("do_it");
xla::XlaOp param0 =
xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
xla::XlaOp param1 =
xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
xla::XlaOp custom_call =
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
/*opaque=*/"", /*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
/*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
/*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}
// Constrain custom call arguments to 1-dimensional buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;
// Implement a custom call as a C++ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
// Check that dimensions are compatible.
assert(out->dimensions[0] == d1 && "unexpected dimensions");
for (size_t i = 0; i < d1; ++i) {
out->data[i] = in0.data[i % d0] + in1.data[i];
}
}
// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Arg<Buffer>()
.Arg<Buffer>()
.Ret<Buffer>());
// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"Host", handler);
GPU에서 맞춤 호출 만들기
XLA FFI를 사용한 GPU 맞춤 호출 등록은 거의 동일하며, 유일한 차이점은 기기에서 커널을 실행하려면 기본 플랫폼 스트림(CUDA 또는 ROCM 스트림)을 요청해야 한다는 것입니다. 다음은 위의 CPU 코드와 동일한 계산 (A[i] = B[i % 128] + C[i])을 실행하는 CUDA 예입니다.
void do_it() { /* same implementation as above */ }
__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
out[idx] = in0[idx % 128] + in1[idx];
}
void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
xla::ffi::Result<BufferF32> out) {
size_t d0 = in0.dimensions[0];
size_t d1 = in1.dimensions[0];
size_t d2 = out->dimensions[0];
assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");
const int64_t block_dim = 64;
const int64_t grid_dim = 2048 / block_dim;
custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
in0.data, in1.data, out->data);
}
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
ffi::Ffi::Bind()
.Ctx<xla::ffi::PlatformStream<CUstream>>()
.Arg<BufferF32>()
.Arg<BufferF32>()
.Ret<BufferF32>());
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
"CUDA", handler);
먼저 GPU 맞춤 호출 함수는 여전히 CPU에서 실행되는 함수입니다. do_custom_call CPU 함수는 GPU에서 작업을 대기열에 추가하는 역할을 합니다. 여기서는 CUDA 커널을 실행하지만 cuBLAS를 호출하는 등 다른 작업을 할 수도 있습니다.
인수와 결과도 호스트에 있으며 데이터 멤버에는 기기 (즉, GPU) 메모리에 대한 포인터가 포함됩니다. 맞춤 호출 핸들러에 전달된 버퍼는 기본 기기 버퍼의 모양을 가지므로 맞춤 호출은 여기에서 커널 실행 매개변수를 계산할 수 있습니다.
맞춤 호출에 튜플 전달
다음 맞춤 호출을 고려해 보세요.
using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {32}),
ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {64}),
ShapeUtil::MakeShape(F32, {128}),
}),
ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");
Shape out_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {512}),
ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);
CPU와 GPU 모두에서 튜플은 메모리에 포인터 배열로 표현됩니다. XLA가 튜플 인수 또는 결과로 맞춤 호출을 호출하면 이를 평면화하고 일반 버퍼 인수 또는 결과로 전달합니다.
임시 버퍼로 튜플 출력
커스텀 호출에 대한 튜플 입력은 편리하지만 엄격하게 필요한 것은 아닙니다. 맞춤 호출에 튜플 입력이 지원되지 않는 경우 맞춤 호출에 전달하기 전에 항상 get-tuple-element를 사용하여 튜플을 압축 해제할 수 있습니다.
반면 튜플 출력을 사용하면 그렇지 않으면 할 수 없는 작업을 할 수 있습니다.
튜플 출력이 있는 명백한 이유는 맞춤 호출 (또는 기타 XLA 작업)이 여러 독립 배열을 반환하는 방식이 튜플 출력이기 때문입니다.
하지만 덜 명시적으로 튜플 출력은 맞춤 호출에 임시 메모리를 제공하는 방법이기도 합니다. 예, 출력은 임시 버퍼를 나타낼 수 있습니다. 출력 버퍼에는 작업이 버퍼에 쓸 수 있고 버퍼에 쓴 후에는 버퍼에서 읽을 수 있다는 속성이 있습니다. 이것이 바로 임시 버퍼에서 원하는 것입니다.
위 예에서 F32[1024]를 임시 버퍼로 사용한다고 가정해 보겠습니다.
그런 다음 위와 같이 HLO를 작성하고 맞춤 호출의 출력에서 튜플 색인 1을 읽지 않습니다.