XLA カスタム呼び出し

このドキュメントでは、XLA FFI ライブラリを使用して XLA カスタム呼び出しを作成し、使用する方法について説明します。カスタム呼び出しは、コンパイル時に HLO モジュール内の外部「演算」を XLA コンパイラに記述するメカニズムです。XLA FFI は、そのような演算の実装を(実行時に)XLA に登録するメカニズムです。FFI は「外部関数インターフェース」の略で、他のプログラミング言語で記述された外部コードを呼び出すために XLA 用のバイナリ インターフェース(ABI)を定義する C API のセットです。XLA は、C++ で記述された XLA FFI にヘッダーのみのバインディングを提供します。これにより、基盤となる C API の低レベルの詳細情報がエンドユーザーから見えなくなります。

CPU のカスタム呼び出しを作成する

XLA のクライアント API を介したカスタム呼び出しを表す HLO 命令を作成できます。たとえば次のコードは、カスタム呼び出しを使用して CPU の A[i] = B[i % 128]+ C[i] を計算します。(もちろん可能です。そのようにすべきです。(これは通常の HLO で行います)。

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to rank-1 buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C+ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

GPU でカスタム呼び出しを作成する

XLA FFI による GPU カスタム呼び出しの登録はほぼ同じです。唯一の違いは、GPU の場合、デバイス上でカーネルを起動するために基盤となるプラットフォーム ストリーム(CUDA または ROCM ストリーム)を要求する必要があることです。上記の CPU コードと同じ計算(A[i] = B[i % 128] + C[i])を行う CUDA の例を次に示します。

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

まず、GPU カスタム呼び出し関数は、引き続き CPU で実行される関数であることに注意してください。do_custom_call CPU 関数は、GPU の作業をキューに登録する役割を担います。ここでは CUDA カーネルを起動しますが、cuBLAS を呼び出すなど、他の処理を行うこともできます。

引数と結果もホスト上にあり、データメンバーにはデバイス(GPU)メモリへのポインタが含まれます。カスタム呼び出しハンドラに渡されるバッファは、基盤となるデバイス バッファの形状を持つため、カスタム呼び出しはそこからカーネルの起動パラメータを計算できます。

タプルをカスタム呼び出しに渡す

次のカスタム呼び出しについて考えてみましょう。

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

CPU と GPU の両方で、タプルはメモリ内でポインタの配列として表されます。XLA は、タプル引数または結果を使用してカスタム呼び出しを呼び出すと、それらをフラット化し、通常のバッファ引数または結果として渡します。

一時バッファとしてのタプル出力

カスタム呼び出しへのタプル入力は便利ですが、厳密に必要なわけではありません。カスタム呼び出しへのタプル入力がサポートされていない場合は、カスタム呼び出しに渡す前に、get-tuple-element を使用してタプルを展開できます。

一方、タプルの出力を使用すれば、他の方法ではできないことができるようになります。

タプル出力を使用する明らかな理由は、タプル出力は、カスタム呼び出し(または他の XLA 演算)で複数の独立した配列を返す方法です。

言うまでもなく、タプル出力はカスタム呼び出しの一時メモリを提供する手段でもあります。はい。出力で一時バッファを表すことができます。出力バッファに op が書き込めるプロパティがあり、書き込み後にそこから読み取ることができるとします。これが、一時バッファに求めることです。

上記の例では、F32[1024] を一時バッファとして使用するとします。次に、上記のように HLO を記述します。カスタム呼び出しの出力のタプル インデックス 1 は決して読み取ることはありません。