本文档介绍如何编写和使用 XLA“自定义调用”。借助自定义调用,您可以从 XLA 程序调用以 C++ 或 CUDA 等编程语言编写的代码。
警告:自定义调用是一种底层高级用户功能。使用自定义调用时,很容易使用难以调试(甚至难以通知)的方式破坏程序。除非您准备在发生问题时自行调试 XLA,否则不应使用自定义调用,而且在陷入困境时,您应当预计到从 XLA 开发者那里得到的帮助会相对较少。
警告:自定义调用 API/ABI 目前不稳定。我们不打算随意更改它,但它也可能发生变化。下文介绍了一些未来可能进行的变更。
CPU 上的自定义调用
您可以通过 XLA 的客户端 API 创建代表自定义调用的 HLO 指令。在撰写本文时,尚未通过 TensorFlow 公开此功能。
例如,以下代码使用自定义调用来计算 `A[i] = B[i % 128]
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
void do_it() {
xla::XlaBuilder b("do_it");
xla::XlaOp param0 =
xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {128}), "p0");
xla::XlaOp param1 =
xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {2048}), "p1");
xla::XlaOp custom_call =
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}));
}
void do_custom_call(void* out, const void** in) {
float* out_buf = reinterpret_cast<float*>(out);
const float* in0 = reinterpret_cast<const float*>(in[0]);
const float* in1 = reinterpret_cast<const float*>(in[1]);
for (int i = 0; i < 2048; ++i) {
out_buf[i] = in0[i % 128] + in1[i];
}
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "Host");
请注意,函数 do_custom_call
需要了解其运行所处的缓冲区的大小。在此示例中,我们对大小 128 和 2048 进行硬编码。如果您不想这样做,则可以将大小作为参数传递给调用。
GPU 上的自定义调用
GPU 自定义调用框架与 CPU 上的框架有所不同。下面是一个 CUDA 示例,它执行与上述 CPU 代码相同的 A[i] = B[i % 128] + C[i]
计算。
void do_it() { /* same implementation as above */ }
__global__ custom_call_kernel(const float* in0, const float* in1, float* out) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
out[idx] = in0[idx % 128] + in1[idx];
}
void do_custom_call(CUstream stream, void** buffers,
const char* opaque, size_t opaque_len) {
const float* in0 = reinterpret_cast<const float*>(buffers[0]);
const float* in1 = reinterpret_cast<const float*>(buffers[1]);
float* out = reinterpret_cast<float*>(buffers[2]);
const int64_t block_dim = 64;
const int64_t grid_dim = 2048 / block_dim;
custom_call_kernel<<<grid_dim, block_dim,
/*dynamic_shared_mem_bytes=*/0, stream>>>(in0, in1, out);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(do_custom_call, "CUDA");
首先请注意,GPU 自定义调用函数仍然是在 CPU 上执行的函数。我们的 do_custom_call
CPU 函数负责将 GPU 上的工作加入队列。在这里,它会启动 CUDA 内核,但也可以执行其他操作,例如调用 cublas。
此功能仅适用于 GPU 自定义调用;无法为 CPU 自定义调用发出失败信号。
与 CPU 示例中一样,我们已将输入和输出缓冲区的大小硬编码到我们的自定义调用中。但是,与 CPU 情况不同,将缓冲区大小作为运算对象传递给自定义调用将无法正常工作。通常,我们需要 CPU 上可用的缓冲区大小;例如,在启动内核时,我们需要了解要使用的块/网格大小。但是,如果我们将缓冲区大小作为运算对象传递给自定义调用,它们的值将驻留在 GPU 内存中。随后,我们必须在运算开始时执行一个消耗大量资源的同步设备-主机 memcpy 来读取大小。
为了解决此问题,我们提供了 opaque
参数。创建自定义调用时,可以将此参数设置为任意字节串:
std::string opaque = "...";
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*output_shape=*/xla::ShapeUtil::MakeShape(xla::F32, {2048}),
opaque);
由于 xla::Shape
具有协议缓冲区表示,因此您可以将此序列化 proto 存储在 opaque
中,并在 GPU 自定义调用中将其反序列化。不过请注意,尽管 xla::ShapeProto
不会经常变化,但有时确实会变化。检查 git 日志以查看其过去发生的变化。
发出错误信号。
如果您的自定义调用遇到错误,您可以通过在 CPU 上为函数使用以下签名来向 XLA 运行时发出错误信号(而不是崩溃或在输出缓冲区中返回废话):
#include "tensorflow/compiler/xla/service/custom_call_status.h"
void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status);
… 在 GPU 上:
#include "tensorflow/compiler/xla/service/custom_call_status.h"
void do_custom_call(CUstream stream, void** buffers, const char* opaque,
size_t opaque_len, xla::XlaCustomCallStatus* status);
您可以使用 XlaCustomCallStatusSetFailure
发出失败信号,例如:
void do_custom_call(void* out, const void** in, XlaCustomCallStatus* status) {
// ... do some work.
if (bad_condition) {
char* error_message = "An error occurred";
XlaCustomCallStatusSetFailure(status, error_message, strlen(error_message));
return;
}
// ... continue.
}
您也可以使用 XlaCustomCallStatusSetSuccess
来指示成功,但 XlaCustomCallStatus
默认处于成功状态,因此完全忽略它也会指示成功。
使用具有此签名的自定义调用函数时,您必须使用适当的 API 版本集创建相应的 custom-call
运算,例如:
xla::CustomCall(&b, "do_custom_call", /*operands=*/{param0, param1},
/*output_shape=*/xla::ShapeUtil::MakeShape(F32, {2048}),
opaque, /*has_side_effect=*/false,
/*output_operand_aliasing=*/{}, /*literal=*/nullptr,
/*schedule=*/xla::CustomCallSchedule::SCHEDULE_NONE,
/*api_version=*/API_VERSION_STATUS_RETURNING);
注:在将来,所有客户端都需要将其自定义调用函数迁移到新的 API 版本,旧版本将被弃用。对于不会失败的自定义调用,您可以简单地添加新的 XlaCustomCallStatus*
参数,然后忽略它。
失败时,不会使用任何自定义调用输出;XLA 运行时将终止计算。HLO 计算无法从错误中恢复(例如通过捕获和处理错误)。
将元组传递给自定义调用
请考虑以下自定义调用。
using xla::ShapeUtil;
using xla::F32;
Shape p0_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {32}),
ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {64}),
ShapeUtil::MakeShape(F32, {128}),
}),
ShapeUtil::MakeShape(F32, {256}),
});
xla::XlaOp p0 = xla::Parameter(0, p0_shape, "p0");
Shape out_shape = ShapeUtil::MakeTuple({
ShapeUtil::MakeShape(F32, {512}),
ShapeUtil::MakeShape(F32, {1024}),
});
xla::CustomCall(&b, "do_custom_call", /*operands=*/{p0}, out_shape);
在 CPU 和 GPU 上,元组在内存中均被表示为指针数组。在 C++ 伪代码中,上述参数 0 的布局如下所示。
// In-memory layout of parameter 0 from custom-call above. True on both CPU
// and GPU.
float* subbuf0 = new float[32];
float* subbuf1 = new float[64];
float* subbuf2 = new float[128]
float* subbuf3 = new float[256];
void* subtuple = new void*[2];
(*subtuple)[0] = subbuf1;
(*subtuple)[1] = subbuf2;
void* p0 = new void*[3];
(*p0)[0] = subbuf0;
(*p0)[1] = subtuple;
(*p0)[2] = subbuf3;
尽管元组的内存中表示在 CPU 和 GPU 中相同,但是在 CPU 和 GPU 自定义调用的调用惯例中,它们的处理方式有所不同。
元组输出作为临时缓冲区
自定义调用的元组输入非常方便,但并非绝对必要。如果我们不支持在自定义调用中使用元组输入,则在将元组传递到自定义调用之前,始终可以使用 get-tuple-element 将它们解包。
另一方面,元组输出可让您完成以其他方式无法完成的任务。
具有元组输出的明显原因是,这是自定义调用(或任何其他 XLA 运算)返回多个独立数组的方式。
但不太明显的是,元组输出也是一种为自定义调用提供临时内存的方式。是的,一个输出可以代表一个临时缓冲区。考虑一下,一个输出缓冲区具有运算可以写入的属性,并且可以在属性被写入后从缓冲区中读取。这正是您想要的临时缓冲区。
在上面的示例中,假设我们要使用 F32[1024]
作为临时缓冲区。随后,我们将像上面那样写入 HLO,这样便永远不会读取自定义调用输出的元组索引 1。
CPU 自定义调用中的元组
在 CPU 代码中,我们有一个函数 do_custom_call(const void** ins, void* out)
。ins
是一个只包含一个元素的数组,此元素指向 param0
。通过解除引用该指针可以访问 param0
的子缓冲区,而通过解除引用 out
则可以访问 output_tuple
的子缓冲区。
GPU 自定义调用中的元组
在 GPU 代码中,我们有一个函数 do_custom_call(..., void** buffers, ...)
。在这种情况下,buffers
是一个包含六个设备指针的主机数组,每个指针对应输入/输出中的每个叶缓冲区。为了生成扁平列表,我们遍历参数和输出,并对每个参数的形状执行先序遍历。具体代码如下:
// Layout of `buffers` parameter to GPU custom call function for custom-call
// above.
buffers[0] == subbuf0
buffers[1] == subbuf1
buffers[2] == subbuf2
buffers[3] == subbuf3
buffers[4] == output_subbuf0
buffers[5] == output_subbuf1