XLA 맞춤 호출

이 문서에서는 XLA 커스텀 호출을 작성하고 사용하는 방법을 설명합니다. 커스텀 호출을 사용하면 XLA 프로그램에서 C++ 또는 CUDA와 같은 프로그래밍 언어로 작성된 코드를 호출할 수 있습니다.

CPU의 커스텀 호출 만들기

XLA의 클라이언트 API를 통해 커스텀 호출을 나타내는 HLO 명령을 만들 수 있습니다. 예를 들어 다음 코드는 커스텀 호출을 사용하여 CPU의 A[i] = B[i % 128]+ C[i]를 계산합니다. (물론 그렇게 할 수 있습니다. 그리고 해야 합니다! 일반 HLO를 사용하면 됩니다.

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                      /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}

void do_custom_call(void* out, const void** in) {
  float* out_buf = reinterpret_cast<float*>(out);
  const float* in0 = reinterpret_cast<const float*>(in[0]);
  const float* in1 = reinterpret_cast<const float*>(in[1]);
  for (int i = 0; i < 2048; ++i) {
    out_buf[i] = in0[i % 128] + in1[i];
  }
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");

do_custom_call 함수는 함수가 작동하는 버퍼의 크기를 알아야 합니다. 이 예에서는 1282048 크기를 하드코딩합니다. 이렇게 하지 않으려면 호출에 측정기준을 매개변수로 전달하면 됩니다.

GPU에서 커스텀 호출 만들기

GPU 맞춤 호출 프레임워크는 CPU의 프레임워크와 다소 다릅니다. 다음은 위의 CPU 코드와 동일한 계산 (A[i] = B[i % 128] + C[i])을 실행하는 CUDA 예시입니다.

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const float* in0 = reinterpret_cast<const float*>(buffers[0]);
  const float* in1 = reinterpret_cast<const float*>(buffers[1]);
  float* out = reinterpret_cast<float*>(buffers[2]);

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim,
                       /*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");

GPU 맞춤 호출 함수는 여전히 CPU에서 실행되는 함수입니다. do_custom_call CPU 함수는 GPU의 작업을 큐에 추가하는 역할을 합니다. 여기서는 CUDA 커널을 실행하지만 cuBLAS 호출과 같은 다른 작업을 할 수도 있습니다.

buffers는 호스트에 있는 포인터의 배열이며 포함된 각 요소는 기기 (즉, GPU) 메모리를 가리킵니다. 매개변수가 먼저 오고 그 뒤에 출력 값이 옵니다. 이는 두 매개변수(insout)가 있는 CPU 호출 규칙과 특히 다릅니다. GPU 호출 규칙을 사용하면 튜플 모양의 입력/출력을 효율적으로 처리할 수 있습니다.

CPU 예에서와 같이 입력 및 출력 버퍼 사이즈를 맞춤 호출에 하드코딩했습니다. 그러나 CPU 사례와 달리 버퍼 크기를 맞춤 호출에 피연산자로 전달하는 것은 제대로 작동하지 않습니다. 일반적으로 CPU에서 사용할 수 있는 버퍼 크기가 필요합니다 (예: 커널을 실행할 때 사용할 블록/그리드 크기를 알아야 함). 하지만 버퍼 크기를 맞춤 호출에 피연산자로 전달하면 그 값은 GPU 메모리에 저장됩니다. 그런 다음 크기만 읽기 위해 작업을 시작할 때 비용이 많이 드는 동기식 기기 간 memcpy를 실행해야 합니다.

이 문제를 해결할 수 있도록 opaque 매개변수를 제공합니다. 맞춤 호출을 만들 때 이 값을 임의의 바이트 문자열로 설정할 수 있습니다.

std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
                opaque);

xla::Shape에는 프로토콜 버퍼 표현이 있으므로 이 직렬화된 proto를 opaque 내부에 저장하고 GPU 맞춤 호출 내에서 역직렬화할 수 있습니다. 그러나 xla::ShapeProto는 자주 변경되지 않지만 변경됩니다. Git 로그를 확인하여 과거에 어떻게 변경되었는지 알아봅니다.

오류 신호

맞춤 호출에 오류가 발생하면 함수에 다음 서명을 사용하여 (출력 버퍼에서 비정상 종료나 말이 안 되는 것을 반환하는 대신) 오류를 XLA 런타임에 알릴 수 있습니다.

CPU의 경우:

#include "xla/service/custom_call_status.h"

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status);

#include "xla/service/custom_call_status.h"

void do_custom_call(CUstream stream, void** buffers, const char* opaque,
                    size_t opaque_len, xla::XlaCustomCallStatus* status);

XlaCustomCallStatusSetFailure를 사용하여 실패를 알릴 수 있습니다.예를 들면 다음과 같습니다.

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status) {
  // ... do some work.

  if (bad_condition) {
    char* error_message = "An error occurred";
    XlaCustomCallStatusSetFailure(status, error_message, strlen(error_message));
    return;
  }

  // ... continue.
}

XlaCustomCallStatusSetSuccess를 사용하여 성공을 나타낼 수도 있지만 XlaCustomCallStatus는 기본적으로 성공 상태이므로 완전히 무시하는 경우에도 성공을 나타냅니다.

이 서명으로 커스텀 호출 함수를 사용할 때는 적절한 API 버전 집합으로 상응하는 custom-call 작업을 만들어야 합니다.예를 들면 다음과 같습니다.

xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(F32, {2048}),
                opaque, /*has_side_effect=*/false,
                /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
                /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
                /*api_version=*/API_VERSION_STATUS_RETURNING);

실패하면 커스텀 호출 출력이 사용되지 않습니다. XLA 런타임에서 계산을 종료합니다. HLO 계산을 통해 오류로부터 복구할 수 없습니다 (예: 이를 포착하고 처리).

맞춤 호출에 튜플 전달

다음 맞춤 호출을 고려하세요.

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);

CPU와 GPU 모두에서 튜플은 메모리에 포인터 배열로 표현됩니다. C++ 의사코드에서 위의 매개변수 0은 다음과 같이 표시됩니다.

// In-memory layout of parameter 0 from custom call above. True on both CPU
// and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128]
float* subbuf3 = new float[256];

void* subtuple = new void*[2];
(*subtuple)[0] = subbuf1;
(*subtuple)[1] = subbuf2;

void* p0 = new void*[3];
(*p0)[0] = subbuf0;
(*p0)[1] = subtuple;
(*p0)[2] = subbuf3;

튜플의 메모리 내 표현은 CPU와 GPU에서 동일하지만 CPU 및 GPU 커스텀 호출 호출 규칙에서는 다르게 처리됩니다.

임시 버퍼로서의 튜플 출력

맞춤 호출에 대한 튜플 입력은 편리하기는 하지만 반드시 필요한 것은 아닙니다. 커스텀 호출에 튜플 입력을 지원하지 않는 경우 커스텀 호출에 전달하기 전에 항상 get-tuple-element를 사용하여 튜플을 압축해제할 수 있습니다.

반면 출력 튜플을 사용하면 할 수 없었던 작업을 할 수 있습니다.

튜플 출력이 있어야 하는 분명한 이유는 튜플 출력이 커스텀 호출 (또는 기타 모든 XLA 작업)이 여러 독립 배열을 반환하는 방법이기 때문입니다.

그러나 덜 분명하지만 튜플 출력은 맞춤 호출 임시 메모리를 제공하는 방법이기도 합니다. 예. 출력은 임시 버퍼를 나타낼 수 있습니다. 출력 버퍼에는 작업이 여기에 쓸 수 있는 속성이 있으며, 쓰기가 완료된 후에 버퍼에서 읽을 수 있다는 점을 고려하세요. 이것이 바로 임시 버퍼에서 원하는 것입니다.

위의 예에서는 F32[1024]를 임시 버퍼로 사용한다고 가정합니다. 그런 다음 위와 마찬가지로 HLO를 작성하면 커스텀 호출 출력의 튜플 색인 1을 절대 읽지 않습니다.

CPU 커스텀 호출의 튜플

CPU 코드에는 do_custom_call(const void** ins, void* out) 함수가 있습니다. insparam0을 가리키는 요소가 하나만 있는 배열입니다. param0의 하위 버퍼는 포인터를 역참조하여 액세스할 수 있고 output_tuple의 하위 버퍼는 out를 역참조하여 액세스할 수 있습니다.

GPU 커스텀 호출의 튜플

GPU 코드에는 do_custom_call(..., void** buffers, ...) 함수가 있습니다. 이 경우 buffers는 입력/출력의 리프 버퍼마다 하나씩, 6 기기 포인터의 호스트 배열입니다. 평면 목록을 생성하기 위해 매개변수와 출력을 반복하고 각 매개변수에 관해 도형의 선주문 순회를 수행합니다. 구체적으로 설명:

// Layout of `buffers` parameter to GPU custom call function for custom-call
// above.
buffers[0] == subbuf0
buffers[1] == subbuf1
buffers[2] == subbuf2
buffers[3] == subbuf3
buffers[4] == output_subbuf0
buffers[5] == output_subbuf1