XLA カスタム呼び出し

このドキュメントでは、XLA カスタム呼び出しを記述して使用する方法について説明します。カスタム呼び出しを使用すると、C++ や CUDA などのプログラミング言語で記述されたコードを XLA プログラムから呼び出せます。

CPU でカスタム呼び出しを作成する

XLA のクライアント API を介したカスタム呼び出しを表す HLO 命令を作成できます。たとえば、次のコードは、カスタム呼び出しを使用して CPU で A[i] = B[i % 128]+ C[i] を計算します。(もちろん可能です。また、そのようにすべきです。- 通常の HLO で行うことができます)。

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                      /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}

void do_custom_call(void* out, const void** in) {
  float* out_buf = reinterpret_cast<float*>(out);
  const float* in0 = reinterpret_cast<const float*>(in[0]);
  const float* in1 = reinterpret_cast<const float*>(in[1]);
  for (int i = 0; i < 2048; ++i) {
    out_buf[i] = in0[i % 128] + in1[i];
  }
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");

関数 do_custom_call は、演算対象となるバッファのディメンションを知る必要があります。この例では、サイズ 1282048 をハードコードしています。この処理を行わない場合は、ディメンションをパラメータとして呼び出しに渡すことができます。

GPU でカスタム呼び出しを作成する

GPU カスタム呼び出しフレームワークは、CPU のフレームワークとは多少異なります。上記の CPU コードと同じ計算(A[i] = B[i % 128] + C[i])を行う CUDA の例を以下に示します。

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, void** buffers,
                    const char* opaque, size_t opaque_len) {
  const float* in0 = reinterpret_cast<const float*>(buffers[0]);
  const float* in1 = reinterpret_cast<const float*>(buffers[1]);
  float* out = reinterpret_cast<float*>(buffers[2]);

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim,
                       /*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");

まず、GPU カスタム呼び出し関数は、CPU 上で実行される関数であることに留意してください。do_custom_call CPU 関数は、GPU の作業をキューに登録します。ここでは CUDA カーネルを起動しますが、cuBLAS の呼び出しなど、他のこともできます。

buffers は、ホスト上に存在するポインタの配列です。各要素には、デバイス(GPU)メモリを指すポイントが含まれています。パラメータが最初に来て、その後に出力値が続きます。この点は、insout の 2 つのパラメータがある CPU 呼び出しの規則とは大きく異なります。GPU 呼び出しの規則により、タプル型の入出力を効率的に処理できます。

CPU の例と同様に、入力バッファと出力バッファのサイズをカスタム呼び出しにハードコードしています。ただし、CPU の場合とは異なり、バッファサイズをオペランドとしてカスタム呼び出しに渡すとうまく機能しません。通常、CPU 上で利用可能なバッファサイズが必要になります(たとえば、カーネルを起動する場合、使用するブロック/グリッドの寸法を把握する必要があります)。しかし、バッファサイズをオペランドとしてカスタム呼び出しに渡すと、その値は GPU メモリに格納されます。この場合、サイズを読み取るためだけに、オペレーションの開始時に高価なデバイス間の同期 memcpy を実行する必要があります。

この問題を回避するために、opaque パラメータを提供しています。カスタム呼び出しを作成するときに、この値を任意のバイト文字列に設定できます。

std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
                opaque);

xla::Shape にはプロトコル バッファ表現があるため、このシリアル化された proto を opaque 内に格納し、GPU カスタム呼び出し内でシリアル化解除できます。ただし、xla::ShapeProto は頻繁には変更されませんが、変更されます。Git ログで、過去にどう変更されているかを確認します。

エラーのシグナリング

カスタム呼び出しでエラーが発生した場合は、関数に次のシグネチャを使用して、エラーを XLA ランタイムに通知できます(たとえば、出力バッファでクラッシュしたり意味不明を返したりする必要はありません)。

CPU 上:

#include "xla/service/custom_call_status.h"

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status);

GPU:

#include "xla/service/custom_call_status.h"

void do_custom_call(CUstream stream, void** buffers, const char* opaque,
                    size_t opaque_len, xla::XlaCustomCallStatus* status);

XlaCustomCallStatusSetFailure を使用してエラーを通知できます。次に例を示します。

void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status) {
  // ... do some work.

  if (bad_condition) {
    char* error_message = "An error occurred";
    XlaCustomCallStatusSetFailure(status, error_message, strlen(error_message));
    return;
  }

  // ... continue.
}

XlaCustomCallStatusSetSuccess を使用して成功を示すこともできますが、XlaCustomCallStatus はデフォルトで成功状態であるため、完全に無視しても成功を示します。

このシグネチャでカスタム呼び出し関数を使用する場合は、対応する custom-call 演算を適切な API バージョン セットで作成する必要があります。次に例を示します。

xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
                /*output_shape=*/xla::ShapeUtil::MakeShape(F32, {2048}),
                opaque, /*has_side_effect=*/false,
                /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
                /*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
                /*api_version=*/API_VERSION_STATUS_RETURNING);

失敗した場合、カスタム呼び出し出力は使用されず、XLA ランタイムが計算を終了します。HLO 計算をエラーから回復させることはできません(たとえば、エラーをキャッチして処理することはできません)。

タプルをカスタム呼び出しに渡す

次のカスタム呼び出しについて考えてみましょう。

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);

CPU と GPU の両方で、タプルはメモリ内でポインタの配列として表されます。C++ 擬似コードで上記のパラメータ 0 は次のように配置されます。

// In-memory layout of parameter 0 from custom call above. True on both CPU
// and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128]
float* subbuf3 = new float[256];

void* subtuple = new void*[2];
(*subtuple)[0] = subbuf1;
(*subtuple)[1] = subbuf2;

void* p0 = new void*[3];
(*p0)[0] = subbuf0;
(*p0)[1] = subtuple;
(*p0)[2] = subbuf3;

タプルのメモリ内表現は CPU と GPU で同じですが、CPU と GPU のカスタムコール呼び出し規則では処理が異なります。

一時バッファとしてのタプル出力

カスタム呼び出しへのタプル入力は便利ですが、厳密に必要なわけではありません。カスタム呼び出しへのタプル入力がサポートされていない場合は、カスタム呼び出しに渡す前に、get-tuple-element を使用してタプルをアンパックできます。

一方、タプル出力を使用すると、できないことが実現されます。

タプル出力がある明らかな理由は、タプル出力は、カスタム呼び出し(または他の XLA 演算)が複数独立した配列を返す方法です。

言うまでもなく、タプル出力はカスタム呼び出し一時メモリを提供する手段でもあります。はい。出力は一時バッファを表すことができます。出力バッファに op が書き込めるプロパティがあり、書き込み後にそこから読み取ることができるとします。これが一時バッファに求めるものです。

上記の例では、F32[1024] を一時バッファとして使用するとします。この場合、上記のように HLO を記述し、カスタム呼び出しの出力のタプル インデックス 1 は決して読み取ることはありません。

CPU カスタム呼び出しのタプル

CPU コードには、関数 do_custom_call(const void** ins, void* out) があります。ins は要素を 1 つだけ含む配列で、param0 を指します。param0 のサブバッファには、そのポインタを逆参照することによってアクセスでき、output_tuple のサブバッファには out を逆参照することによってアクセスできます。

GPU カスタム呼び出しのタプル

GPU コードには、関数 do_custom_call(..., void** buffers, ...) があります。この場合、buffers は 6 つのデバイス ポインタ(入出力のリーフバッファごとに 1 つずつ)のホスト配列です。フラットリストを生成するには、パラメータと出力を反復処理し、それぞれに対してその形状の事前注文走査を行います。具体的に:

// Layout of `buffers` parameter to GPU custom call function for custom-call
// above.
buffers[0] == subbuf0
buffers[1] == subbuf1
buffers[2] == subbuf2
buffers[3] == subbuf3
buffers[4] == output_subbuf0
buffers[5] == output_subbuf1