XLA カスタム呼び出し

このドキュメントでは、XLA FFI ライブラリを使用して XLA カスタム呼び出しを作成して使用する方法について説明します。カスタム呼び出しは、コンパイル時に HLO モジュールの外部「オペレーション」を XLA コンパイラに記述するメカニズムです。XLA FFI は、実行時にこのようなオペレーションの実装を XLA に登録するメカニズムです。FFI は「他言語関数インターフェース」の略で、XLA が他のプログラミング言語で記述された外部コードを呼び出すためのバイナリ インターフェース(ABI)を定義する C API のセットです。XLA は、C++ で記述された XLA FFI のヘッダーのみのバインディングを提供します。これにより、基盤となる C API の低レベルの詳細がエンドユーザーから隠されます。

JAX + XLA カスタム呼び出し

カスタム呼び出しと XLA FFI を JAX と統合するエンドツーエンドの例については、JAX のドキュメントをご覧ください。

XLA FFI バインディング

XLA FFI バインディングは、カスタム呼び出しシグネチャのコンパイル時の仕様です。カスタム呼び出し引数、属性とその型、実行コンテキスト(GPU バックエンドの GPU ストリームなど)を介して渡される追加のパラメータが含まれます。XLA FFI バインディングは、互換性のある operator() シグネチャを持つ任意の C++ 呼び出し可能オブジェクト(関数ポインタ、ラムダなど)にバインドできます。構築されたハンドラは、XLA FFI 呼び出しフレーム(安定版 C API で定義)をデコードし、すべてのパラメータの型チェックを行い、デコードされた結果をユーザー定義のコールバックに転送します。

XLA FFI バインディングは、構築されたハンドラを最も効率的なマシンコードにコンパイルするために、テンプレート メタプログラミングに大きく依存しています。実行時のオーバーヘッドは、カスタム呼び出しパラメータごとに数ナノ秒のオーダーです。

テンプレートの特殊化として実装された XLA FFI カスタマイズ ポイント。ユーザーはカスタム型のデコード方法を定義できます。つまり、ユーザー定義の enum class 型のカスタム デコードを定義できます。

カスタム呼び出しからのエラーの返信

カスタム呼び出しの実装では、XLA ランタイムに成功またはエラーを通知するために xla::ffi::Error 値を返す必要があります。absl::Status と同様で、同じエラーコードのセットがあります。absl::Status は安定した ABI を持っておらず、動的に読み込まれたカスタム呼び出しライブラリと XLA 自体の間で渡すのは安全ではないため、使用しません。

// Handler that always returns an error.
auto always_error = Ffi::Bind().To(
    []() { return Error(ErrorCode::kInternal, "Oops!"); });

// Handler that always returns a success.
auto always_success = Ffi::Bind().To(
    []() { return Error::Success(); });

バッファ引数と結果

XLA は結果に宛先渡しスタイルを使用します。カスタム呼び出し(またはその他の XLA オペレーション)は結果にメモリを割り当てず、代わりに XLA ランタイムによって渡された宛先に書き込みます。XLA は静的バッファ割り当てを使用し、コンパイル時にライブ範囲に基づいてすべての値のバッファを割り当てます。

FFI ハンドラに渡される結果は、ポインタのようなセマンティクスを持つ Result<T> テンプレートにラップされます。operator-> は基盤となるパラメータへのアクセスを提供します。

AnyBuffer の引数と結果により、任意のデータ型のカスタム呼び出しバッファ パラメータにアクセスできます。これは、カスタムコールに複数のデータ型で動作する汎用実装があり、カスタムコール実装がデータ型に基づいて実行時ディスパッチを行う場合に便利です。AnyBuffer は、バッファ データ型、ディメンション、バッファ自体へのポインタへのアクセス権を付与します。

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  api_version = 4 : i32
} : (tensor<2x2xf32>) -> tensor<2x2xf32>
// Buffers of any number of dimensions and data type.
auto handler = Ffi::Bind().Arg<AnyBuffer>().Ret<AnyBuffer>().To(
    [](AnyBuffer arg, Result<AnyBuffer> res) -> Error {
      void* arg_data = arg.untyped_data();
      void* res_data = res->untyped_data();
      return Error::Success();
    });

制約付きバッファ引数と結果

Buffer を使用すると、バッファ データ型とディメンションの数に制約を追加できます。ランタイム引数が FFI ハンドラのシグネチャと一致しない場合、ハンドラによって自動的にチェックされ、XLA ランタイムにエラーが返されます。

// Buffers of any number of dimensions and F32 data type.
auto handler = Ffi::Bind().Arg<Buffer<F32>>().Ret<Buffer<F32>>().To(
    [](Buffer<F32> arg, Result<Buffer<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });
// Buffers of number of dimensions 2 and F32 data type.
auto handler = Ffi::Bind().Arg<BufferR2<F32>>().Ret<BufferR2<F32>>().To(
    [](BufferR2<F32> arg, Result<BufferR2<F32>> res) -> Error {
      float* arg_data = arg.typed_data();
      float* res_data = res->typed_data();
      return Error::Success();
    });

可変長引数と結果

カスタムコールのインスタンスごとに引数と結果の数が異なる場合は、RemainingArgsRemainingRets を使用して実行時にデコードできます。

auto handler = Ffi::Bind().RemainingArgs().RemainingRets().To(
    [](RemainingArgs args, RemainingRets results) -> Error {
      ErrorOr<AnyBuffer> arg = args.get<AnyBuffer>(0);
      ErrorOr<Result<AnyBuffer>> res = results.get<AnyBuffer>(0);

      if (!arg.has_value()) {
        return Error(ErrorCode::kInternal, arg.error());
      }

      if (!res.has_value()) {
        return Error(ErrorCode::kInternal, res.error());
      }

      return Error::Success();
    });

可変長引数と結果は、通常の引数と結果の後に宣言できますが、可変長引数の後に通常の引数と結果をバインドすることはできません。

auto handler =
    Ffi::Bind()
        .Arg<AnyBuffer>()
        .RemainingArgs()
        .Ret<AnyBuffer>()
        .RemainingRets()
        .To([](AnyBuffer arg, RemainingArgs args, AnyBuffer ret,
               RemainingRets results) -> Error { return Error::Success(); });

属性

XLA FFI は、custom_call backend_config として渡された mlir::DictionaryAttr を FFI ハンドラ引数に自動的にデコードすることをサポートしています。

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    i32 = 42 : i32,
    str = "string"
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

この例のカスタム呼び出しには、1 つのバッファ引数と 2 つの属性があります。XLA FFI はこれらを自動的にデコードして、ユーザー定義の呼び出し可能オブジェクトに渡すことができます。

auto handler = Ffi::Bind()
  .Arg<BufferR0<F32>>()
  .Attr<int32_t>("i32")
  .Attr<std::string_view>("str")
  .To([](BufferR0<F32> buffer, int32_t i32, std::string_view str) {
    return Error::Success();
  });

ユーザー定義の列挙型属性

XLA FFI は、整数 MLIR 属性をユーザー定義の列挙型に自動的にデコードできます。列挙型クラスは同じ基盤となる整数型を持つ必要があり、デコードは XLA FFI に明示的に登録する必要があります。

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    command = 0 : i32
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>
enum class Command : int32_t {
  kAdd = 0,
  kMul = 1,
};

XLA_FFI_REGISTER_ENUM_ATTR_DECODING(Command);

auto handler = Ffi::Bind().Attr<Command>("command").To(
    [](Command command) -> Error { return Error::Success(); });

すべてのカスタム通話属性をバインドする

すべてのカスタム通話属性を辞書として取得し、実行時に必要な属性のみを遅延デコードできます。

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<int32_t> i32 = attrs.get<int32_t>("i32");
  return Error::Success();
});

ユーザー定義の構造体属性

XLA FFI は、辞書属性をユーザー定義の構造体にデコードできます。

%0 = "stablehlo.custom_call"(%arg0) {
  call_target_name = "foo",
  backend_config= {
    range = { lo = 0 : i64, hi = 42 : i64 }
  },
  api_version = 4 : i32
} : (tensor<f32>) -> tensor<f32>

上記の例では、rangemlir::DictionaryAttr 属性です。名前で辞書フィールドにアクセスする代わりに、C++ 構造体として自動的にデコードできます。デコードは XLA_FFI_REGISTER_STRUCT_ATTR_DECODING マクロで明示的に登録する必要があります(内部的には ::xla::ffi 名前空間でテンプレートの特殊化を定義するため、マクロをグローバル名前空間に追加する必要があります)。

struct Range {
  int64_t lo;
  int64_t hi;
};

XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(Range, StructMember<int64_t>("lo"),
                                             StructMember<int64_t>("hi"));

auto handler = Ffi::Bind().Attr<Range>("range").To([](Range range) -> Error{
  return Error::Success();
});

カスタム属性は、他の属性と同様に、辞書から読み込むことができます。次の例では、すべてのカスタム通話属性が Dictionary としてデコードされ、range に名前でアクセスできます。

auto handler = Ffi::Bind().Attrs().To([](Dictionary attrs) -> Error {
  ErrorOr<Range> range = attrs.get<Range>("range");
  return Error::Success();
});

CPU でカスタム呼び出しを作成する

XLA のクライアント API を介してカスタム呼び出しを表す HLO 命令を作成できます。たとえば、次のコードでは、カスタム呼び出しを使用して CPU で A[i] = B[i % 128]+ C[i] を計算しています。(もちろん、そうすることもできますし、そうすべきです。– 通常の HLO でこれを行います)。

#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"

void do_it() {
  xla::XlaBuilder b("do_it");
  xla::XlaOp param0 =
      xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
  xla::XlaOp param1 =
      xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
  xla::XlaOp custom_call =
      xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
        /*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
        /*opaque=*/"", /*has_side_effect=*/false,
        /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
        /*schedule=*/CustomCallSchedule::SCHEDULE_NONE,
        /*api_version=*/CustomCallApiVersion::API_VERSION_TYPED_FFI);
}

// Constrain custom call arguments to 1-dimensional buffers of F32 data type.
using BufferF32 = xla::ffi::BufferR1<xla::ffi::DataType::F32>;

// Implement a custom call as a C++ function. Note that we can use `Buffer` type
// defined by XLA FFI that gives us access to buffer data type and shape.
xla::ffi::Error do_custom_call(BufferF32 in0, BufferF32 in1,
                               xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];

  // Check that dimensions are compatible.
  assert(out->dimensions[0] == d1 && "unexpected dimensions");

  for (size_t i = 0; i < d1; ++i) {
    out->data[i] = in0.data[i % d0] + in1.data[i];
  }
}

// Explicitly define an XLA FFI handler signature and bind it to the
// `do_custom_call` implementation. XLA FFI handler can automatically infer
// type signature from the custom call function, but it relies on magical
// template metaprogramming an explicit binding provides and extra level of
// type checking and clearly states custom call author intentions.
XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Arg<Buffer>()
                           .Arg<Buffer>()
                           .Ret<Buffer>());

// Registers `handler` with and XLA FFI on a "Host" platform.
XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "Host", handler);

GPU でカスタム呼び出しを作成する

XLA FFI を使用した GPU カスタム呼び出しの登録はほぼ同じです。唯一の違いは、GPU の場合、デバイスでカーネルを起動できるように、基盤となるプラットフォーム ストリーム(CUDA または ROCM ストリーム)をリクエストする必要があることです。上記の CPU コードと同じ計算(A[i] = B[i % 128] + C[i])を行う CUDA の例を次に示します。

void do_it() { /* same implementation as above */ }

__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  out[idx] = in0[idx % 128] + in1[idx];
}

void do_custom_call(CUstream stream, BufferF32 in0, BufferF32 in1,
                    xla::ffi::Result<BufferF32> out) {
  size_t d0 = in0.dimensions[0];
  size_t d1 = in1.dimensions[0];
  size_t d2 = out->dimensions[0];

  assert(d0 == 128 && d1 == 2048 && d2 == 2048 && "unexpected dimensions");

  const int64_t block_dim = 64;
  const int64_t grid_dim = 2048 / block_dim;
  custom_call_kernel<<<grid_dim, block_dim, 0, stream>>>(
    in0.data, in1.data, out->data);
}

XLA_FFI_DEFINE_HANDLER(handler, do_custom_call,
                       ffi::Ffi::Bind()
                           .Ctx<xla::ffi::PlatformStream<CUstream>>()
                           .Arg<BufferF32>()
                           .Arg<BufferF32>()
                           .Ret<BufferF32>());

XLA_FFI_REGISTER_HANDLER(xla::ffi::GetXlaFfiApi(), "do_custom_call",
                         "CUDA", handler);

まず、GPU カスタム呼び出し関数は依然として CPU で実行される関数であることに注意してください。do_custom_call CPU 関数は、GPU での作業のキューイングを担当します。ここでは CUDA カーネルを起動していますが、cuBLAS を呼び出すなど、他の処理を行うこともできます。

引数と結果もホストに存在し、データ メンバーにはデバイス(GPU)メモリへのポインタが含まれます。カスタム呼び出しハンドラに渡されるバッファは、基盤となるデバイス バッファの形状をしているため、カスタム呼び出しはそれらからカーネル起動パラメータを計算できます。

カスタム呼び出しへのタプルの受け渡し

次のカスタム呼び出しについて考えてみましょう。

using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
    ShapeUtil::MakeShape(F32, {32}),
    ShapeUtil::MakeTuple({
        ShapeUtil::MakeShape(F32, {64}),
        ShapeUtil::MakeShape(F32, {128}),
    }),
    ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");

Shape out_shape = ShapeUtil::MakeTuple({
  ShapeUtil::MakeShape(F32, {512}),
  ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape, ...);

CPU と GPU の両方で、タプルはメモリ内でポインタの配列として表されます。XLA がタプル引数または結果を含むカスタム呼び出しを呼び出すと、それらがフラット化され、通常のバッファ引数または結果として渡されます。

タプルを一時バッファとして出力する

カスタム呼び出しへのタプル入力は便利ですが、厳密には必須ではありません。カスタム呼び出しへのタプル入力がサポートされていない場合でも、get-tuple-element を使用してタプルをアンパックしてから、カスタム呼び出しに渡すことができます。

一方、タプルの出力を使用すると、他の方法ではできない処理を実行できます。

タプル出力を使用する明らかな理由は、カスタム呼び出し(または他の XLA オペレーション)が複数の独立した配列を返す方法がタプル出力であるためです。

あまり知られていませんが、タプル出力はカスタム呼び出しの一時メモリを提供する方法でもあります。はい。出力は一時バッファを表すことができます。出力バッファには、op が書き込み可能で、書き込み後に読み取り可能であるというプロパティがあります。これは、一時バッファに求められるまさにその機能です。

上記の例で、F32[1024] を一時バッファとして使用するとします。この場合、HLO は上記のように記述し、カスタム呼び出しの出力のタプル インデックス 1 を読み取らないようにします。