Skip to content

Commit

Permalink
[XPU] interpolate support fp16 (PaddlePaddle#52358)
Browse files Browse the repository at this point in the history
  • Loading branch information
csy0225 authored Mar 31, 2023
1 parent d83d89e commit 3996f0d
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 18 deletions.
6 changes: 4 additions & 2 deletions paddle/phi/backends/xpu/xpu2_op_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::FLOAT16,
phi::DataType::INT32,
phi::DataType::INT64})},
{"bilinear_interp_v2", XPUKernelSet({phi::DataType::FLOAT32})},
{"bilinear_interp_v2",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"bilinear_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"bitwise_not", XPUKernelSet({phi::DataType::BOOL})},
{"broadcast", XPUKernelSet({phi::DataType::FLOAT32})},
Expand Down Expand Up @@ -496,7 +497,8 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT64})},
{"multi_encoder_xpu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"nearest_interp_v2", XPUKernelSet({phi::DataType::FLOAT32})},
{"nearest_interp_v2",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"nearest_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"not_equal",
XPUKernelSet({phi::DataType::INT64,
Expand Down
42 changes: 26 additions & 16 deletions paddle/phi/kernels/xpu/interpolate_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ void InterpolateKernel(
bool align_corners,
int align_mode,
DenseTensor* output) {
using XPUType = typename XPUTypeTrait<T>::Type;
const DataLayout data_layout = phi::StringToDataLayout(data_layout_str);
int n, c, in_d, in_h, in_w;
phi::funcs::ExtractNCDWH(x.dims(), data_layout, &n, &c, &in_d, &in_h, &in_w);
Expand Down Expand Up @@ -140,18 +141,19 @@ void InterpolateKernel(
errors::InvalidArgument("XPU nearest is only support NCHW"));
}

int r = xpu::interpolate2d<T>(ctx.x_context(),
x.data<T>(),
output->data<T>(),
n,
c,
in_h,
in_w,
out_h,
out_w,
nearest,
trans_mode,
(data_layout == DataLayout::kNCHW));
int r =
xpu::interpolate2d<XPUType>(ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<XPUType*>(output->data<T>()),
n,
c,
in_h,
in_w,
out_h,
out_w,
nearest,
trans_mode,
(data_layout == DataLayout::kNCHW));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "interpolate2d");
}

Expand Down Expand Up @@ -221,14 +223,22 @@ void NearestInterpKernel(

} // namespace phi

PD_REGISTER_KERNEL(
bilinear_interp, XPU, ALL_LAYOUT, phi::BilinearInterpKernel, float) {
PD_REGISTER_KERNEL(bilinear_interp,
XPU,
ALL_LAYOUT,
phi::BilinearInterpKernel,
phi::dtype::float16,
float) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_KERNEL(
nearest_interp, XPU, ALL_LAYOUT, phi::NearestInterpKernel, float) {
PD_REGISTER_KERNEL(nearest_interp,
XPU,
ALL_LAYOUT,
phi::NearestInterpKernel,
phi::dtype::float16,
float) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(3).SetBackend(phi::Backend::ALL_BACKEND);
Expand Down

0 comments on commit 3996f0d

Please sign in to comment.