XLA 커스텀 호출

이 문서에서는 XLA FFI 라이브러리를 사용하여 XLA 커스텀 호출을 작성하고 사용하는 방법을 설명합니다. 맞춤 호출은 HLO 모듈의 외부 '작업'을 컴파일 시점에 XLA 컴파일러에 설명하는 메커니즘이며, XLA FFI는 런타임에 XLA로 이러한 작업의 구현을 등록하는 메커니즘입니다. FFI는 '외래 함수 인터페이스'를 의미하며, XLA가 다른 프로그래밍 언어로 작성된 외부 코드를 호출할 수 있도록 바이너리 인터페이스 (ABI)를 정의하는 C API 집합입니다. XLA는 C++로 작성된 XLA FFI의 헤더 전용 바인딩을 제공하며, 이 바인딩은 기본 C API의 모든 하위 수준 세부정보를 최종 사용자에게 숨깁니다.

CPU에서 맞춤 호출 만들기

XLA의 클라이언트 API를 통해 커스텀 호출을 나타내는 HLO 명령을 만들 수 있습니다. 예를 들어 다음 코드는 커스텀 호출을 사용하여 CPU의 A[i] = B[i % 128]+ C[i]를 계산합니다. (물론 그렇게 할 수 있어야 합니다. 일반 HLO를 사용하면 됩니다.)

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla
::XlaBuilder b("do_it");
  xla
::XlaOp param0 =
      xla
::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla
::XlaOp param1 =
      xla
::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla
::XlaOp custom_call =
      xla
::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
       
/*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
       
/*opaque=*/"", /*has_side_effect=*/false,
       
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
       
/*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
       
/*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla
::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla
::ffi::Result<BufferF32> out) {
  size_t d0
= in0.dimensions[0];
  size_t d1
= in1.dimensions[0];

 
// Check that dimensions are compatible.
 
assert(out->dimensions[0] == d1 && "unexpected dimensions");

 
for (size_t i = 0; i < d1; ++i) {
   
out->data[i] = in0.data[i % d0] + in1.data[i];
 
}
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER
(handler, do_custom_call,
                       ffi
::Ffi::Bind()
                           
.Arg<Buffer>()
                           
.Arg<Buffer>()
                           
.Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER
(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         
"Host", handler);

GPU에서 커스텀 호출 만들기

XLA FFI를 사용한 GPU 커스텀 호출 등록은 거의 동일합니다. 유일한 차이점은 GPU의 경우 기기에서 커널을 실행할 수 있도록 기본 플랫폼 스트림(CUDA 또는 ROCM 스트림)을 요청해야 한다는 것입니다. 다음은 위의 CPU 코드와 동일한 계산 (A[i] = B[i % 128] + C[i])을 실행하는 CUDA 예입니다.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel
(const float* in0, const float* in1, float* out) {
  size_t idx
= blockIdx.x * blockDim.x + threadIdx.x;
 
out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla
::ffi::Result<BufferF32> out) {
  size_t d0
= in0.dimensions[0];
  size_t d1
= in1.dimensions[0];
  size_t d2
= out->dimensions[0];

 
assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

 
const int64_t block_dim = 64;
 
const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel
<<<grid_dim, block_dim, 0, stream>>>(
    in0
.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER
(handler, do_custom_call,
                       ffi
::Ffi::Bind()
                           
.Ctx<xla::ffi::PlatformStream<CUstream>>()
                           
.Arg<BufferF32>()
                           
.Arg<BufferF32>()
                           
.Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER
(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         
"CUDA", handler);

GPU 맞춤 호출 함수는 여전히 CPU에서 실행되는 함수입니다. do_custom_call CPU 함수는 GPU의 작업을 큐에 추가합니다. 여기서는 CUDA 커널을 실행하지만 cuBLAS 호출과 같은 다른 작업도 할 수 있습니다.

인수와 결과도 호스트에 상주하며 데이터 멤버에는 기기 (즉, GPU) 메모리를 가리키는 포인터가 포함됩니다. 맞춤 호출 핸들러에 전달되는 버퍼는 기본 기기 버퍼의 형태를 취하므로 맞춤 호출은 버퍼에서 커널 실행 매개변수를 계산할 수 있습니다.

튜플을 커스텀 호출에 전달

다음과 같은 맞춤 호출을 고려해 보세요.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
   
ShapeUtil::MakeShape(F32, {32}),
   
ShapeUtil::MakeTuple({
       
ShapeUtil::MakeShape(F32, {64}),
       
ShapeUtil::MakeShape(F32, {128}),
   
}),
   
ShapeUtil::MakeShape(F32, {256}),
});
xla
::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
 
ShapeUtil::MakeShape(F32, {512}),
 
ShapeUtil::MakeShape(F32, {1024}),
});
xla
::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

CPU와 GPU 모두에서 튜플은 메모리에 포인터 배열로 표현됩니다. XLA는 튜플 인수 또는 결과를 사용하여 맞춤 호출을 호출할 때 이를 평면화하고 일반 버퍼 인수 또는 결과로 전달합니다.

임시 버퍼로서의 튜플 출력

커스텀 호출에 대한 튜플 입력은 편리하지만 반드시 필요한 것은 아닙니다. 맞춤 호출에 튜플 입력을 지원하지 않는 경우 맞춤 호출에 전달하기 전에 항상 get-tuple-element를 사용하여 튜플을 압축해제할 수 있습니다.

반면 outputs 튜플은 다른 방법으로는 할 수 없었던 작업을 가능하게 합니다.

튜플 출력이 있어야 하는 명백한 이유는 튜플 출력이 맞춤 호출 (또는 기타 모든 XLA 작업)이 여러 독립 배열을 반환하는 방식 때문입니다.

그러나 튜플 출력은 맞춤 호출 임시 메모리를 제공하는 방법이기도 합니다. 예. 출력은 임시 버퍼를 나타낼 수 있습니다. 출력 버퍼에는 작업이 여기에 쓸 수 있는 속성이 있으며, 쓰기가 완료된 후에 버퍼에서 읽을 수 있습니다. 이것이 바로 임시 버퍼에서 원하는 것입니다.

위의 예에서는 F32[1024]를 임시 버퍼로 사용한다고 가정합니다. 그런 다음 위와 같이 HLO를 작성하면 커스텀 호출 출력의 튜플 색인 1을 절대 읽지 않습니다.