From e300929924a61249df79ede92ce0dcf51a811d01 Mon Sep 17 00:00:00 2001
From: HongyuJia
Date: Thu, 23 Mar 2023 14:20:07 +0800
Subject: [PATCH] [CustomOP Supports Inplace] Add chapter for inplace mechanism
 (#5745)

* [CustomOP Supports Inplace] Add chapter for inplace mechanism

* polish typo format
---
 docs/guides/custom_op/new_cpp_op_cn.md | 372 ++++++++++++++++++-------
 1 file changed, 266 insertions(+), 106 deletions(-)

diff --git a/docs/guides/custom_op/new_cpp_op_cn.md b/docs/guides/custom_op/new_cpp_op_cn.md
index 9652cf55829..a75666c2251 100644
--- a/docs/guides/custom_op/new_cpp_op_cn.md
+++ b/docs/guides/custom_op/new_cpp_op_cn.md
@@ -831,112 +831,6 @@ std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
 }
 ```
 
-#### Support for custom devices
-
-First, follow the [example of adding a new hardware device](https://www.paddlepaddle.org.cn/documentation/docs/zh/dev_guides/custom_device_docs/custom_device_example_cn.html) to make sure the custom device has been registered.
-
-If neither the CPU implementation nor the GPU implementation meets the needs of the new hardware, the custom operator can be implemented by composing C++ operation APIs. The example below rewrites the CPU implementation from the aforementioned `relu_cpu.cc` as a composition of C++ operation APIs:
-
-```c++
-#include "paddle/extension.h"
-
-#include <vector>
-
-#define CHECK_CUSTOM_INPUT(x) PD_CHECK(x.is_custom_device(), #x " must be a custom Tensor.")
-
-std::vector<paddle::Tensor> relu_custom_forward(const paddle::Tensor& x) {
-  CHECK_CUSTOM_INPUT(x);
-  auto out = paddle::relu(x);
-  return {out};
-}
-
-std::vector<paddle::Tensor> relu_custom_backward(
-    const paddle::Tensor& x,
-    const paddle::Tensor& out,
-    const paddle::Tensor& grad_out) {
-  CHECK_CUSTOM_INPUT(x);
-  CHECK_CUSTOM_INPUT(out);
-  auto grad_x = paddle::empty_like(x, x.dtype(), x.place());
-  auto ones = paddle::experimental::full_like(x, 1.0, x.dtype(), x.place());
-  auto zeros = paddle::experimental::full_like(x, 0.0, x.dtype(), x.place());
-  auto condition = paddle::experimental::greater_than(x, zeros);
-
-  grad_x = paddle::multiply(grad_out, paddle::where(condition, ones, zeros));
-
-  return {grad_x};
-}
-
-std::vector<paddle::Tensor> relu_custom_double_backward(
-    const paddle::Tensor& out, const paddle::Tensor& ddx) {
-  CHECK_CUSTOM_INPUT(out);
-  auto ddout = paddle::empty(out.shape(), out.dtype(), out.place());
-  auto ones = paddle::experimental::full_like(out, 1.0, out.dtype(), out.place());
-  auto zeros = paddle::experimental::full_like(out, 0.0, out.dtype(), out.place());
-  auto condition = paddle::experimental::greater_than(out, zeros);
-
-  ddout = paddle::multiply(ddx, paddle::where(condition, ones, zeros));
-
-  return {ddout};
-}
-
-std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
-  if (x.is_cpu()) {
-    return relu_cpu_forward(x);
-  } else if (x.is_custom_device()) {
-    return relu_custom_forward(x);
-  } else {
-    PD_THROW("Not implemented.");
-  }
-}
-
-std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
-                                         const paddle::Tensor& out,
-                                         const paddle::Tensor& grad_out) {
-  if (x.is_cpu()) {
-    return relu_cpu_backward(x, out, grad_out);
-  } else if (x.is_custom_device()) {
-    return relu_custom_backward(x, out, grad_out);
-  } else {
-    PD_THROW("Not implemented.");
-  }
-}
-
-std::vector<paddle::Tensor> ReluDoubleBackward(const paddle::Tensor& out,
-                                               const paddle::Tensor& ddx) {
-  if (out.is_cpu()) {
-    return relu_cpu_double_backward(out, ddx);
-  } else if (out.is_custom_device()) {
-    return relu_custom_double_backward(out, ddx);
-  } else {
-    PD_THROW("Not implemented.");
-  }
-}
-```
-
-For the supported C++ operation APIs, see [Python-like C++ operation APIs](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/custom_op/new_cpp_op_cn.html#python-c-api).
-
-##### Getting the stream of a custom device
-
-To obtain the device `stream`, you can get the `stream` of the corresponding `Tensor` as follows (the header `#include "paddle/phi/backends/all_context.h"` must be added):
-
-```c++
-#include "paddle/extension.h"
-#include "paddle/phi/backends/all_context.h"
-
-#define CHECK_CUSTOM_INPUT(x) \
-  PD_CHECK(x.is_custom_device(), #x " must be a custom Tensor.")
-
-void* GetStream(const paddle::Tensor& x) {
-  CHECK_CUSTOM_INPUT(x);
-
-  auto dev_ctx =
-      paddle::experimental::DeviceContextPool::Instance().Get(x.place());
-  auto custom_ctx = static_cast<const phi::CustomContext*>(dev_ctx);
-  void* stream = custom_ctx->stream();
-  PD_CHECK(stream != nullptr);
-
-  return stream;
-}
-```
 
 ### Implementing the shape and dtype inference functions
 
@@ -1159,6 +1053,272 @@ std::vector<paddle::Tensor> AttrTestBackward(
     const std::vector& c) {...}
 ```
 
+### Other features
+
+#### Supporting custom devices
+
+First, follow the [example of adding a new hardware device](https://www.paddlepaddle.org.cn/documentation/docs/zh/dev_guides/custom_device_docs/custom_device_example_cn.html) to make sure the custom device has been registered.
+
+If neither the CPU implementation nor the GPU implementation meets the needs of the new hardware, the custom operator can be implemented by composing C++ operation APIs. The example below rewrites the CPU implementation from the aforementioned `relu_cpu.cc` as a composition of C++ operation APIs:
+
+```c++
+#include "paddle/extension.h"
+
+#include <vector>
+
+#define CHECK_CUSTOM_INPUT(x) PD_CHECK(x.is_custom_device(), #x " must be a custom Tensor.")
+
+std::vector<paddle::Tensor> relu_custom_forward(const paddle::Tensor& x) {
+  CHECK_CUSTOM_INPUT(x);
+  auto out = paddle::relu(x);
+  return {out};
+}
+
+std::vector<paddle::Tensor> relu_custom_backward(
+    const paddle::Tensor& x,
+    const paddle::Tensor& out,
+    const paddle::Tensor& grad_out) {
+  CHECK_CUSTOM_INPUT(x);
+  CHECK_CUSTOM_INPUT(out);
+  auto grad_x = paddle::empty_like(x, x.dtype(), x.place());
+  auto ones = paddle::experimental::full_like(x, 1.0, x.dtype(), x.place());
+  auto zeros = paddle::experimental::full_like(x, 0.0, x.dtype(), x.place());
+  auto condition = paddle::experimental::greater_than(x, zeros);
+
+  grad_x = paddle::multiply(grad_out, paddle::where(condition, ones, zeros));
+
+  return {grad_x};
+}
+
+std::vector<paddle::Tensor> relu_custom_double_backward(
+    const paddle::Tensor& out, const paddle::Tensor& ddx) {
+  CHECK_CUSTOM_INPUT(out);
+  auto ddout = paddle::empty(out.shape(), out.dtype(), out.place());
+  auto ones = paddle::experimental::full_like(out, 1.0, out.dtype(), out.place());
+  auto zeros = paddle::experimental::full_like(out, 0.0, out.dtype(), out.place());
+  auto condition = paddle::experimental::greater_than(out, zeros);
+
+  ddout = paddle::multiply(ddx, paddle::where(condition, ones, zeros));
+
+  return {ddout};
+}
+
+std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
+  if (x.is_cpu()) {
+    return relu_cpu_forward(x);
+  } else if (x.is_custom_device()) {
+    return relu_custom_forward(x);
+  } else {
+    PD_THROW("Not implemented.");
+  }
+}
+
+std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
+                                         const paddle::Tensor& out,
+                                         const paddle::Tensor& grad_out) {
+  if (x.is_cpu()) {
+    return relu_cpu_backward(x, out, grad_out);
+  } else if (x.is_custom_device()) {
+    return relu_custom_backward(x, out, grad_out);
+  } else {
+    PD_THROW("Not implemented.");
+  }
+}
+
+std::vector<paddle::Tensor> ReluDoubleBackward(const paddle::Tensor& out,
+                                               const paddle::Tensor& ddx) {
+  if (out.is_cpu()) {
+    return relu_cpu_double_backward(out, ddx);
+  } else if (out.is_custom_device()) {
+    return relu_custom_double_backward(out, ddx);
+  } else {
+    PD_THROW("Not implemented.");
+  }
+}
+```
+
+For the supported C++ operation APIs, see [Python-like C++ operation APIs](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/custom_op/new_cpp_op_cn.html#python-c-api).
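+
+The snippet above only defines the device-dispatching functions; they still have to be registered before they can be used. A minimal registration sketch, mirroring the `custom_relu` registration pattern used earlier in this document (the operator name is illustrative):
+
+```c++
+// register forward, first-order and second-order gradient kernels
+PD_BUILD_OP(custom_relu)
+    .Inputs({"X"})
+    .Outputs({"Out"})
+    .SetKernelFn(PD_KERNEL(ReluForward));
+
+PD_BUILD_GRAD_OP(custom_relu)
+    .Inputs({"X", "Out", paddle::Grad("Out")})
+    .Outputs({paddle::Grad("X")})
+    .SetKernelFn(PD_KERNEL(ReluBackward));
+
+PD_BUILD_DOUBLE_GRAD_OP(custom_relu)
+    .Inputs({"Out", paddle::Grad(paddle::Grad("X"))})
+    .Outputs({paddle::Grad(paddle::Grad("Out"))})
+    .SetKernelFn(PD_KERNEL(ReluDoubleBackward));
+```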
"paddle/extension.h" -#include "paddle/phi/backends/all_context.h" - -#define CHECK_CUSTOM_INPUT(x) \ - PD_CHECK(x.is_custom_device(), #x " must be a custom Tensor.") - -void* GetStream(const paddle::Tensor& x) { - CHECK_CUSTOM_INPUT(x); - - auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(x.place()); - auto custom_ctx = static_cast(dev_ctx); - void* stream = custom_ctx->stream(); - PD_CHECK(stream != nullptr); - - return stream; -} -``` ### 维度与类型推导函数实现 @@ -1159,6 +1053,272 @@ std::vector AttrTestBackward( const std::vector& c) {...} ``` +### 其他功能 + +#### 支持自定义设备 + +首先请参考 [新硬件接入示例](https://www.paddlepaddle.org.cn/documentation/docs/zh/dev_guides/custom_device_docs/custom_device_example_cn.html) 确保自定义设备已经注册完成。 + +如果 CPU 实现和 GPU 实现无法满足新硬件的需求,可以通过组合 C++ 运算 API 的方式,实现自定义算子。将前述 `relu_cpu.cc` 中的 CPU 实现改为组合 C++ 运算 API 的示例如下: + +```c++ +#include "paddle/extension.h" + +#include + +#define CHECK_CUSTOM_INPUT(x) PD_CHECK(x.is_custom_device(), #x " must be a custom Tensor.") + +std::vector relu_custom_forward(const paddle::Tensor& x) { + CHECK_CUSTOM_INPUT(x); + auto out = paddle::relu(x); + return {out}; +} + +std::vector relu_custom_backward( + const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out) { + CHECK_CUSTOM_INPUT(x); + CHECK_CUSTOM_INPUT(out); + auto grad_x = paddle::empty_like(x, x.dtype(), x.place()); + auto ones = paddle::experimental::full_like(x, 1.0, x.dtype(), x.place()); + auto zeros = paddle::experimental::full_like(x, 0.0, x.dtype(), x.place()); + auto condition = paddle::experimental::greater_than(x, zeros); + + grad_x = paddle::multiply(grad_out, paddle::where(condition, ones, zeros)); + + return {grad_x}; +} + +std::vector relu_custom_double_backward( + const paddle::Tensor& out, const paddle::Tensor& ddx) { + CHECK_CUSTOM_INPUT(out); + auto ddout = paddle::empty(out.shape(), out.dtype(), out.place()); + auto ones = paddle::experimental::full_like(out, 1.0, out.dtype(), out.place()); + auto zeros = paddle::experimental::full_like(out, 0.0, out.dtype(), out.place()); + auto condition = paddle::experimental::greater_than(out, zeros); + + ddout = paddle::multiply(ddx, paddle::where(condition, ones, zeros)); + + return {ddout}; +} + +std::vector ReluForward(const paddle::Tensor& x) { + if (x.is_cpu()) { + return relu_cpu_forward(x); + } else if (x.is_custom_device()) { + return relu_custom_forward(x); + } else { + PD_THROW("Not implemented."); + } +} + +std::vector ReluBackward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out) { + if (x.is_cpu()) { + return relu_cpu_backward(x, out, grad_out); + } else if (x.is_custom_device()) { + return relu_custom_backward(x, out, grad_out); + } else { + PD_THROW("Not implemented."); + } +} + +std::vector ReluDoubleBackward(const paddle::Tensor& out, + const paddle::Tensor& ddx) { + if (out.is_cpu()) { + return relu_cpu_double_backward(out, ddx); + } else if (out.is_custom_device()) { + return relu_custom_double_backward(out, ddx); + } else { + PD_THROW("Not implemented."); + } +} +``` + +支持的 C++ 运算 API 可参考 [类 Python 的 C++运算 API](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/custom_op/new_cpp_op_cn.html#python-c-api) + +##### 获取自定义设备的 stream + +用户想要获取设备的 `stream` 时,可以通过下述方式获取对应 `Tensor` 的 `stream`(需要添加头文件 `#include "paddle/phi/backends/all_context.h"`): + +```c++ +#include "paddle/extension.h" +#include "paddle/phi/backends/all_context.h" + +#define CHECK_CUSTOM_INPUT(x) \ + PD_CHECK(x.is_custom_device(), #x " must be a custom 
Tensor.") + +void* GetStream(const paddle::Tensor& x) { + CHECK_CUSTOM_INPUT(x); + + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(x.place()); + auto custom_ctx = static_cast(dev_ctx); + void* stream = custom_ctx->stream(); + PD_CHECK(stream != nullptr); + + return stream; +} +``` + +#### inplace 机制 + +使用 inplace 机制定义的自定义算子,可以指定输入和输出使用同一个 Tensor,或者对输入的 Tensor 做原位修改。 + +下面结合具体的使用示例进行介绍,将 `relu` 算子改写为 inplace 算子,函数实现如下: + +```c++ +#include "paddle/extension.h" + +#include + +template +void relu_forward_kernel(data_t* x_data, int64_t numel) { + for (size_t i = 0; i < numel; ++i) { + x_data[i] = x_data[i] > 0 ? x_data[i] : 0; + } +} + +template +void relu_backward_kernel(const data_t* out_data, + data_t* grad_out_data, + int64_t out_numel) { + for (int64_t i = 0; i < out_numel; ++i) { + grad_out_data[i] = + grad_out_data[i] * (out_data[i] > static_cast(0) ? 1. : 0.); + } +} + +void ReluCpuInplaceForward(paddle::Tensor& x) { // NOLINT + PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor."); + + PD_DISPATCH_FLOATING_TYPES(x.type(), "ReluForward", ([&] { + relu_forward_kernel(x.data(), + x.size()); + })); +} + +void ReluCpuInplaceBackward(const paddle::Tensor& x, + const paddle::Tensor& out, + paddle::Tensor& grad_out) { // NOLINT + PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor."); + + PD_DISPATCH_FLOATING_TYPES( + grad_out.type(), "ReluBackward", ([&] { + relu_backward_kernel( + out.data(), grad_out.data(), grad_out.size()); + })); +} + +PD_BUILD_OP(custom_inplace_relu) + .Inputs({"X"}) + .Outputs({"Out"}) + .SetInplaceMap({{"X", "Out"}}) + .SetKernelFn(PD_KERNEL(ReluCpuInplaceForward)); + +PD_BUILD_GRAD_OP(custom_inplace_relu) + .Inputs({"X", "Out", paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetInplaceMap({{paddle::Grad("Out"), paddle::Grad("X")}}) + .SetKernelFn(PD_KERNEL(ReluCpuInplaceBackward)); +``` + +相比于 `relu` 算子的常规实现,使用 inplace 机制需要注意以下几点: + +1. 输入的 inplace Tensor 类型,应该修改为 `paddle::Tensor&` 而非 `const paddle::Tensor&`; + +2. 定义算子时,需要使用 `SetInplaceMap` 指明输入和输出间 inplace 的映射关系。`SetInplaceMap` 传入的参数类型为 `std::unordered_map`,支持多组输入和输出之间进行 inplace 映射。例如可以定义: `.SetInplaceMap({{"X", "Out1"}, {"Y", "Out2"}})`; + +3. 一方面,做 inplace 映射的输出 Tensor,不再作为函数的返回值,如果此时函数没有需要返回的 Tensor,函数的输出类型应为 `void` ;另一方面,其他没有做 inplace 映射的输出 Tensor,仍需作为返回值显式输出,此时函数的输出类型仍为 `std::vector`。例如 `ReluCpuInplaceForward` 函数中不再显式输出 Tensor,因此函数返回类型为 `void`; + +4. 
+
+The following custom inplace implementation of `custom_add` illustrates the points above:
+
+```c++
+#include "paddle/extension.h"
+
+#include <vector>
+
+template <typename data_t>
+void add_forward_kernel(data_t* x_data, const data_t* y_data, int64_t numel) {
+  for (int64_t i = 0; i < numel; ++i) {
+    x_data[i] += y_data[i];
+  }
+}
+
+template <typename data_t>
+void add_backward_kernel(data_t* y_grad_data,
+                         const data_t* out_grad_data,
+                         int64_t numel) {
+  for (int64_t i = 0; i < numel; ++i) {
+    y_grad_data[i] = out_grad_data[i];
+  }
+}
+
+// An output Tensor with an inplace mapping is no longer returned; since no
+// other Tensor needs to be returned here, the return type is `void`.
+void AddForward(paddle::Tensor& x,  // an inplace input must be `paddle::Tensor&`
+                const paddle::Tensor& y) {
+  PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor.");
+
+  PD_DISPATCH_FLOATING_TYPES(x.type(), "AddForward", ([&] {
+                               add_forward_kernel<data_t>(x.data<data_t>(),
+                                                          y.data<data_t>(),
+                                                          x.size());
+                             }));
+  // the output Tensor "Out" has an inplace mapping, so it is not returned explicitly
+}
+
+// the input types of the InferDtype function need no special changes
+std::vector<paddle::DataType> AddInferDtype(const paddle::DataType& x_dtype,
+                                            const paddle::DataType& y_dtype) {
+  return {x_dtype};
+}
+
+// the input types of the InferShape function need no special changes
+std::vector<std::vector<int64_t>> AddInferShape(
+    const std::vector<int64_t>& x_shape, const std::vector<int64_t>& y_shape) {
+  return {x_shape};
+}
+
+// output Tensors without an inplace mapping must still be returned explicitly,
+// so the return type remains std::vector<paddle::Tensor>
+std::vector<paddle::Tensor> AddBackward(const paddle::Tensor& x,
+                                        const paddle::Tensor& y,
+                                        paddle::Tensor& out_grad) {  // an inplace input must be `paddle::Tensor&`
+  PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor.");
+  PD_CHECK(y.place() == paddle::PlaceType::kCPU, "y must be a CPU Tensor.");
+
+  paddle::Tensor y_grad = paddle::empty(x.shape(), x.dtype(), x.place());
+
+  PD_DISPATCH_FLOATING_TYPES(
+      out_grad.type(), "AddBackward", ([&] {
+        add_backward_kernel<data_t>(
+            y_grad.data<data_t>(), out_grad.data<data_t>(), out_grad.size());
+      }));
+
+  // y_grad has no inplace mapping, so it must still be returned explicitly
+  return {y_grad};
+}
+
+PD_BUILD_OP(custom_add)
+    .Inputs({"X", "Y"})
+    .Outputs({"Out"})
+    .SetInplaceMap({{"X", "Out"}})  // `SetInplaceMap` specifies the inplace mapping between inputs and outputs
+    .SetKernelFn(PD_KERNEL(AddForward))
+    .SetInferShapeFn(PD_INFER_SHAPE(AddInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(AddInferDtype));
+
+PD_BUILD_GRAD_OP(custom_add)
+    .Inputs({"X", "Y", paddle::Grad("Out")})
+    .Outputs({paddle::Grad("X"), paddle::Grad("Y")})
+    .SetInplaceMap({{paddle::Grad("Out"), paddle::Grad("X")}})  // `SetInplaceMap` must come after `Inputs` and `Outputs`
+    .SetKernelFn(PD_KERNEL(AddBackward));
+```
+
+
 ## Compiling and using custom operators
 
 This mechanism provides two ways to compile a custom operator: **compiling with `setuptools`** and **just-in-time (JIT) compilation**, each introduced below with an example.