Skip to content

Commit

Permalink
[PTen]Reshape Kernel Refactor (PaddlePaddle#37164)
Browse files Browse the repository at this point in the history
* reshape kernel refactor

* fix compile bugs when running CI

* support XPU for reshape

* fix bugs when running unit tests in Kunlun CI

* fix compile bugs when running on Kunlun

* improve code according to review suggestions
  • Loading branch information
YuanRisheng authored Nov 14, 2021
1 parent 228eb89 commit 895692e
Show file tree
Hide file tree
Showing 19 changed files with 684 additions and 135 deletions.
4 changes: 4 additions & 0 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1883,6 +1883,10 @@ void OperatorWithKernel::BuildPtenKernelContext(
pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int>))) {
pt_kernel_context_->EmplaceBackAttr(
BOOST_GET_CONST(std::vector<int>, attr));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"unsupported cast op attribute `%s` when construct "
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/imperative/prepared_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,9 @@ static void BuildDygraphPtenKernelContext(
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int>))) {
kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector<int>, attr));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"unsupported cast op attribute `%s` when construct "
Expand Down
170 changes: 122 additions & 48 deletions paddle/fluid/operators/reshape_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@ limitations under the License. */
#include <string>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/pten_utils.h"

// only can include the headers in paddle/pten/api dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/manipulation.h"
namespace paddle {
namespace framework {
class InferShapeContext;
Expand Down Expand Up @@ -248,13 +253,6 @@ class ReshapeOp : public framework::OperatorWithKernel {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");

//#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}

Expand Down Expand Up @@ -366,13 +364,6 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");

//#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
Expand All @@ -382,42 +373,117 @@ class ReshapeKernel {
void operator()(const framework::ExecutionContext &ctx) const {
auto *out = ctx.Output<framework::LoDTensor>("Out");
auto *in = ctx.Input<framework::LoDTensor>("X");

framework::DDim out_dims = out->dims();
// framework::DDim out_dims = out->dims();
auto pt_x = paddle::experimental::MakePtenDenseTensor(*in);

// we can't MakePtenDenseTensor by out, because reshape will realloc memory
// and this will throw error(can't realloc shared memory) in current
// DenseTensor
// design. So, codes below create a tmp densetensor for output.
// TODO(YuanRisheng) we can use MakePtenDenseTensor after #36916 merge.
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
pten::DenseTensorMeta meta{pten::TransToPtenDataType(in->type()),
in->dims(),
pten::TransToPtenDataLayout(in->layout())};
auto pt_out_tmp =
std::make_shared<pten::DenseTensor>(alloc, std::move(meta));
pten::DenseTensor *pt_out = nullptr;
if (in == out) {
pt_out = pt_x.get();
} else {
pt_out = pt_out_tmp.get();
}

auto list_new_shape_tensor =
ctx.MultiInput<framework::Tensor>("ShapeTensor");
auto *shape_tensor = ctx.HasInput("Shape")
? ctx.Input<framework::LoDTensor>("Shape")
: nullptr;
if (list_new_shape_tensor.size() > 0) {
// have shape tensor
auto new_shape = get_new_shape(list_new_shape_tensor);
out_dims = ReshapeOp::ValidateShape(new_shape, in->dims());
std::vector<pten::DenseTensor> pt_vec_shape;
for (auto &tensor : list_new_shape_tensor) {
if (platform::is_gpu_place(tensor->place()) ||
platform::is_xpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
pt_vec_shape.push_back(
std::move(*(paddle::experimental::MakePtenDenseTensor(temp))));
} else {
pt_vec_shape.push_back(
std::move(*(paddle::experimental::MakePtenDenseTensor(*tensor))));
}
}
if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out);
}
#endif
#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out);
}
#endif
} else if (shape_tensor) {
std::unique_ptr<pten::DenseTensor> pt_shape;
if (platform::is_gpu_place(shape_tensor->place()) ||
platform::is_xpu_place(shape_tensor->place())) {
framework::Tensor temp;
TensorCopySync(*shape_tensor, platform::CPUPlace(), &temp);
pt_shape = paddle::experimental::MakePtenDenseTensor(temp);
} else {
pt_shape = paddle::experimental::MakePtenDenseTensor(*shape_tensor);
}

if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out);
}
#endif
#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out);
}
#endif
} else {
auto *shape_tensor = ctx.HasInput("Shape")
? ctx.Input<framework::LoDTensor>("Shape")
: nullptr;

if (shape_tensor) {
auto *shape_data = shape_tensor->data<int>();
framework::Tensor cpu_shape_tensor;
if (platform::is_gpu_place(shape_tensor->place()) ||
platform::is_xpu_place(shape_tensor->place())) {
TensorCopySync(*shape_tensor, platform::CPUPlace(),
&cpu_shape_tensor);
shape_data = cpu_shape_tensor.data<int>();
}
auto shape =
std::vector<int>(shape_data, shape_data + shape_tensor->numel());
out_dims = ReshapeOp::ValidateShape(shape, in->dims());
auto &shape_vec = ctx.Attr<std::vector<int>>("shape");
if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out);
}
#endif
#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out);
}
#endif
}
// non-inplace need move all result from pt_out to out, inplace need set
// result dims.
if (in != out) {
paddle::experimental::MovesStorage(pt_out, static_cast<Tensor *>(out));
} else {
out->Resize(pt_out->dims());
}

out->Resize(out_dims);
out->mutable_data(ctx.GetPlace(), in->type());
framework::TensorCopy(
*in, ctx.GetPlace(),
ctx.template device_context<platform::DeviceContext>(), out);
out->Resize(out_dims);
}
};

Expand Down Expand Up @@ -479,6 +545,21 @@ class Reshape2Op : public ReshapeOp {

ReshapeOp::InferShape(ctx);
}

framework::KernelSignature GetExpectedPtenKernelArgs(
    const framework::ExecutionContext &ctx) const override {
  // Select the pten reshape2 kernel variant based on how the target shape
  // is supplied at runtime: a list of shape tensors ("ShapeTensor"), a
  // single shape tensor ("Shape"), or the static "shape" attribute.
  const auto shape_tensor_list =
      ctx.MultiInput<framework::Tensor>("ShapeTensor");
  if (!shape_tensor_list.empty()) {
    return framework::KernelSignature(
        "reshape2.mulhost.mid", {"X", "ShapeTensor"}, {}, {"XShape", "Out"});
  }
  if (ctx.HasInput("Shape")) {
    return framework::KernelSignature("reshape2.host.mid", {"X", "Shape"}, {},
                                      {"XShape", "Out"});
  }
  return framework::KernelSignature("reshape2.mid", {"X"}, {"shape"},
                                    {"XShape", "Out"});
}
};

class Reshape2OpMaker : public ReshapeOpMaker {
Expand Down Expand Up @@ -557,13 +638,6 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));

//#ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// framework::DataLayout::kMKLDNN,
// framework::LibraryType::kMKLDNN);
// }
//#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}

Expand Down
72 changes: 27 additions & 45 deletions paddle/pten/core/kernel_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,34 +114,16 @@ struct KernelRegistrar {
KernelArgsParseFn args_parse_fn,
KernelArgsDefFn args_def_fn,
KernelFn kernel_fn) {
if (layout == DataLayout::ANY) {
for (size_t layout_iter = static_cast<size_t>(DataLayout::NHWC);
layout_iter != static_cast<size_t>(DataLayout::NUM_DATA_LAYOUTS);
layout_iter++) {
for (size_t dtype = static_cast<size_t>(DataType::BOOL);
dtype != static_cast<size_t>(DataType::NUM_DATA_TYPES);
dtype++) {
ConstructKernel(kernel_name_cstr,
backend,
static_cast<DataLayout>(layout_iter),
static_cast<DataType>(dtype),
args_parse_fn,
args_def_fn,
kernel_fn);
}
}
} else {
for (size_t dtype = static_cast<size_t>(DataType::BOOL);
dtype != static_cast<size_t>(DataType::NUM_DATA_TYPES);
dtype++) {
ConstructKernel(kernel_name_cstr,
backend,
layout,
static_cast<DataType>(dtype),
args_parse_fn,
args_def_fn,
kernel_fn);
}
for (size_t dtype = static_cast<size_t>(DataType::BOOL);
dtype != static_cast<size_t>(DataType::NUM_DATA_TYPES);
dtype++) {
ConstructKernel(kernel_name_cstr,
backend,
layout,
static_cast<DataType>(dtype),
args_parse_fn,
args_def_fn,
kernel_fn);
}
}

Expand All @@ -158,7 +140,6 @@ struct KernelRegistrar {
Kernel kernel(kernel_fn);
args_parse_fn(kernel_key, kernel.mutable_args_def());
args_def_fn(&kernel);

KernelFactory::Instance().InsertCompatibleOpType(kernel_name.name());
KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel;
}
Expand Down Expand Up @@ -838,21 +819,22 @@ struct KernelRegistrar {
_PT_REGISTER_KERNEL_WITH_NO_TYPE( \
kernel_name, PT_ID, backend, layout, meta_kernel_fn)

#define _PT_REGISTER_KERNEL_WITH_NO_TYPE( \
kernel_name, func_id, backend, layout, meta_kernel_fn) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
PT_CONCATENATE(pt_op_kernel_ns_check_, func_id), \
"PT_REGISTER_KERNEL must be called in global namespace."); \
decltype(meta_kernel_fn) meta_kernel_fn; \
static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \
func_id)(::pten::Kernel*); \
static const ::pten::KernelRegistrar __reg_pt_op_kernel_##func_id( \
kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::pten::KernelArgsParseFunctor<decltype(&meta_kernel_fn)>::Parse, \
&PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id), \
PT_KERNEL(meta_kernel_fn)); \
void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \
#define _PT_REGISTER_KERNEL_WITH_NO_TYPE( \
kernel_name, func_id, backend, layout, meta_kernel_fn) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
PT_CONCATENATE(pt_op_kernel_ns_check_, func_id), \
"PT_REGISTER_KERNEL must be called in global namespace."); \
decltype(meta_kernel_fn) meta_kernel_fn; \
static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \
func_id)(::pten::Kernel*); \
static const ::pten::KernelRegistrar PT_CONCATENATE(__reg_pt_op_kernel_, \
func_id)( \
kernel_name, \
BACKEND(backend), \
DATALAYOUT(layout), \
::pten::KernelArgsParseFunctor<decltype(&meta_kernel_fn)>::Parse, \
&PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id), \
PT_KERNEL(meta_kernel_fn)); \
void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \
func_id)(::pten::Kernel * kernel)
} // namespace pten
1 change: 1 addition & 0 deletions paddle/pten/core/kernel_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ struct KernelImpl<Return (*)(Args...), kernel_fn> {
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(paddle::platform::float16);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int64_t>&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);

/* Output Helpers */

Expand Down
13 changes: 13 additions & 0 deletions paddle/pten/include/manipulation.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,17 @@ DenseTensor Flatten(const ContextT& dev_ctx,
return dense_out;
}

template <typename T, typename ContextT>
DenseTensor Reshape(const ContextT& dev_ctx,
                    const DenseTensor& x,
                    const std::vector<int>& shape) {
  // Derive the output meta (dims, dtype, layout) implied by `shape`, then
  // allocate a DenseTensor on the context's place and fill it through the
  // vector-valued reshape kernel.
  auto result_meta = InferShapeFromVecValue(x.meta(), shape);
  const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
      dev_ctx.GetPlace());
  pten::DenseTensor result(alloc, result_meta);
  ReshapeFromVectorVal(dev_ctx, x, shape, &result);
  return result;
}

} // namespace pten
Loading

0 comments on commit 895692e

Please sign in to comment.