本文件說明如何編寫及使用 XLA 自訂呼叫。自訂呼叫可讓您透過 XLA 程式叫用以 C++ 或 CUDA 等程式設計語言編寫的程式碼。
在 CPU 上建立自訂呼叫
您可以透過 XLA 的用戶端 API 建立 HLO 指令來表示自訂呼叫。例如,以下程式碼使用自訂呼叫,在 CPU 上計算 A[i] = B[i %
128]+ C[i]
。(當然可以,而且應該這麼做!- 使用一般 HLO 執行此功能)
#include "xla/client/xla_builder.h"
#include "xla/service/custom_call_target_registry.h"
void do_it() {
xla::XlaBuilder b("do_it");
xla::XlaOp param0 =
xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
xla::XlaOp param1 =
xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
xla::XlaOp custom_call =
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}
void do_custom_call(void* out, const void** in) {
float* out_buf = reinterpret_cast<float*>(out);
const float* in0 = reinterpret_cast<const float*>(in[0]);
const float* in1 = reinterpret_cast<const float*>(in[1]);
for (int i = 0; i < 2048; ++i) {
out_buf[i] = in0[i % 128] + in1[i];
}
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");
請注意,do_custom_call
函式需要知道其運作的緩衝區尺寸。在這個範例中,我們對大小 128
和 2048
進行硬式編碼。如果不想執行這項作業,您可以將維度做為參數傳入呼叫。
在 GPU 上建立自訂呼叫
GPU 自訂呼叫架構與 CPU 上略有不同。以下 CUDA 範例執行的運算與上述 CPU 程式碼相同 (A[i] = B[i % 128] + C[i]
)。
void do_it() { /* same implementation as above */ }
__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
out[idx] = in0[idx % 128] + in1[idx];
}
void do_custom_call(CUstream stream, void** buffers,
const char* opaque, size_t opaque_len) {
const float* in0 = reinterpret_cast<const float*>(buffers[0]);
const float* in1 = reinterpret_cast<const float*>(buffers[1]);
float* out = reinterpret_cast<float*>(buffers[2]);
const int64_t block_dim = 64;
const int64_t grid_dim = 2048 / block_dim;
custom_call_kernel<<<grid_dim, block_dim,
/*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");
請注意,GPU 自訂呼叫函式仍是在 CPU 上執行的函式,do_custom_call
CPU 函式負責將 GPU 上的工作排入佇列。此工具會啟動 CUDA 核心,但也可以執行其他操作,例如呼叫 cuBLAS。
buffers
是位於主機上的指標陣列,其中包含指向裝置 (即 GPU) 記憶體的每個元素。參數會先依序列出,接著是輸出值。這與 CPU 呼叫慣例不同,後者有兩個參數:ins
和 out
。GPU 呼叫慣例可讓您有效率地處理元組的輸入/輸出內容。
與 CPU 範例一樣,我們已將輸入和輸出緩衝區空間硬式編碼為自訂呼叫。然而,與 CPU 情況不同的是,將緩衝區大小做為運算元傳遞至自訂呼叫將無法正常運作。我們通常會需要 CPU 上的緩衝區大小 (例如,啟動核心時,必須知道要使用的區塊/格線維度)。不過,如果我們將緩衝區大小做為運算元傳遞給自訂呼叫,這些緩衝區值會保存在 GPU 記憶體中。接著,我們需要在作業開始時執行昂貴的同步裝置對代管 memcpy
,以便讀取大小。
為協助您解決這個問題,我們提供 opaque
參數。您可以在建立自訂呼叫時,將此設為任意位元組字串:
std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
opaque);
由於 xla::Shape
有通訊協定緩衝區表示法,因此您可以將這個序列化的 proto 儲存在 opaque
中,並在 GPU 自訂呼叫內去序列化。不過請注意,雖然 xla::ShapeProto
不會經常變更,但「確實」會變更。請查看 Git 記錄,瞭解過去的異動。
指出錯誤
如果自訂呼叫遇到錯誤,您可以在函式中使用下列簽名,將錯誤信號傳送至 XLA 執行階段 (而非在輸出緩衝區中異常終止或傳回不合理):
在 CPU 上:
#include "xla/service/custom_call_status.h"
void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status);
使用 GPU:
#include "xla/service/custom_call_status.h"
void do_custom_call(CUstream stream, void** buffers, const char* opaque,
size_t opaque_len, xla::XlaCustomCallStatus* status);
您可以使用 XlaCustomCallStatusSetFailure
表示失敗,例如:
void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status) {
// ... do some work.
if (bad_condition) {
char* error_message = "An error occurred";
XlaCustomCallStatusSetFailure(status, error_message, strlen(error_message));
return;
}
// ... continue.
}
您也可以使用 XlaCustomCallStatusSetSuccess
表示成功,但 XlaCustomCallStatus
預設會處於成功狀態,因此完全忽略也可以表示成功。
搭配這個簽章使用自訂呼叫函式時,您必須建立適當的 custom-call
運算,並提供適當的 API 版本集,例如:
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*output_shape=*/xla::ShapeUtil::MakeShape(F32, {2048}),
opaque, /*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
/*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
/*api_version=*/API_VERSION_STATUS_RETURNING);
失敗時,系統不會使用任何自訂呼叫輸出;XLA 執行階段將終止計算。HLO 運算無法從錯誤中復原 (例如擷取及處理錯誤)。
將元組傳送至自訂呼叫
請考慮使用下列自訂呼叫。
using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {32}),
ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {64}),
ShapeUtil::MakeShape(F32, {128}),
}),
ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");
Shape out_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {512}),
ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);
在 CPU 和 GPU 上,元組會以指標陣列的形式在記憶體中表示。在 C++ 虛擬程式碼中,上方的參數 0 呈現如下。
// In-memory layout of parameter 0 from custom call above. True on both CPU
// and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128]
float* subbuf3 = new float[256];
void* subtuple = new void*[2];
(*subtuple)[0] = subbuf1;
(*subtuple)[1] = subbuf2;
void* p0 = new void*[3];
(*p0)[0] = subbuf0;
(*p0)[1] = subtuple;
(*p0)[2] = subbuf3;
雖然 CPU 和 GPU 在記憶體內配置的表示法相同,但 CPU 和 GPU 自訂呼叫呼叫慣例的處理方式不同。
以臨時緩衝區形式輸出的元組輸出內容
自訂呼叫的元組輸入很方便,但其實沒有必要性。如果我們不支援自訂呼叫元組輸入內容,您可以先使用 get-tuple-element 將元組解壓縮,再將其傳送至自訂呼叫。
另一方面,元組的「輸出」可讓您執行原本無法執行的操作。
出現元組輸出的明顯原因是,元組輸出是自訂呼叫 (或任何其他 XLA 運算) 傳回多個獨立陣列的方式。
但較不明顯,元組輸出也是提供自訂呼叫暫存記憶體的方法。可以,「輸出」可以代表暫存緩衝區。請考慮到輸出緩衝區,其中含有運算可寫入的屬性,且可在寫入後從輸出緩衝區讀取。這就是您需要的臨時緩衝區。
在上述範例中,假設我們想使用 F32[1024]
做為暫存緩衝區。接下來,我們會如上所示寫入 HLO,而我們絕不會讀取自訂呼叫輸出內容的元組索引 1。
CPU 自訂呼叫中的元組
在 CPU 程式碼中,我們有 do_custom_call(const void** ins, void* out)
函式。ins
是一個陣列,只有一個元素,而該元素指向 param0
。可解開該指標即可存取 param0
的子緩衝區,而宣告 out
可存取 output_tuple
的子緩衝區。
GPU 自訂呼叫中的元組
在 GPU 程式碼中,我們有 do_custom_call(..., void** buffers, ...)
函式。在此情況下,buffers
是 6 裝置指標的主機陣列,其中每個分葉緩衝區在輸入/輸出中各有一個。為產生平面清單,我們會反覆處理參數和輸出內容,並且針對每個參數進行預購。具體做法:
// Layout of `buffers` parameter to GPU custom call function for custom-call
// above.
buffers[0] == subbuf0
buffers[1] == subbuf1
buffers[2] == subbuf2
buffers[3] == subbuf3
buffers[4] == output_subbuf0
buffers[5] == output_subbuf1