Skip to content

Commit

Permalink
[PHI] Move segment_pool to phi. (PaddlePaddle#40099)
Browse files Browse the repository at this point in the history
* move segment_pool to phi.

* mark summed ids as optional tensor.

* fix as reviews.
  • Loading branch information
ZHUI authored Mar 10, 2022
1 parent 548f2be commit a07f19e
Show file tree
Hide file tree
Showing 22 changed files with 666 additions and 405 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ cc_library(common_infer_shape_functions SRCS common_infer_shape_functions.cc DEP

set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows_utils lapack_function
lod_tensor maxouting unpooling pooling lod_rank_table context_project
sequence_pooling segment_pooling executor device_memory_aligment generator)
sequence_pooling executor device_memory_aligment generator)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse matrix_solve)
Expand Down
1 change: 0 additions & 1 deletion paddle/fluid/operators/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ math_library(vol2col)
math_library(prelu)
math_library(bert_encoder_functor)
math_library(tree2col DEPS math_function)
math_library(segment_pooling)
math_library(matrix_solve)

cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor)
Expand Down
37 changes: 9 additions & 28 deletions paddle/fluid/operators/segment_pool_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/segment_pool_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"

namespace paddle {
namespace operators {
Expand All @@ -23,22 +26,6 @@ class SegmentPoolOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SegmentPool");
OP_INOUT_CHECK(ctx->HasInput("SegmentIds"), "Input", "SegmentIds",
"SegmentPool");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SegmentPool");
auto dims = ctx->GetInputDim("X");
dims[0] = -1;
ctx->SetOutputDim("Out", dims);

if (ctx->Attrs().Get<std::string>("pooltype") == "MEAN") {
OP_INOUT_CHECK(ctx->HasOutput("SummedIds"), "Output", "SummedIds",
"SegmentPool");
ctx->SetOutputDim("SummedIds", {-1, 1});
}
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
Expand Down Expand Up @@ -150,17 +137,11 @@ class SegmentPoolGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle

namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(segment_pool, SegmentPoolInferShapeFunctor,
PD_INFER_META(phi::SegmentPoolInferMeta));

REGISTER_OPERATOR(segment_pool, ops::SegmentPoolOp, ops::SegmentPoolOpMaker,
ops::SegmentPoolGradOpMaker<paddle::framework::OpDesc>,
ops::SegmentPoolGradOpMaker<paddle::imperative::OpBase>);
ops::SegmentPoolGradOpMaker<paddle::imperative::OpBase>,
SegmentPoolInferShapeFunctor);
REGISTER_OPERATOR(segment_pool_grad, ops::SegmentPoolGradOp);

REGISTER_OP_CPU_KERNEL(
segment_pool,
ops::SegmentPoolKernel<paddle::platform::CPUDeviceContext, float>,
ops::SegmentPoolKernel<paddle::platform::CPUDeviceContext, double>);

REGISTER_OP_CPU_KERNEL(
segment_pool_grad,
ops::SegmentPoolGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SegmentPoolGradKernel<paddle::platform::CPUDeviceContext, double>);
27 changes: 0 additions & 27 deletions paddle/fluid/operators/segment_pool_op.cu

This file was deleted.

176 changes: 0 additions & 176 deletions paddle/fluid/operators/segment_pool_op.h

This file was deleted.

4 changes: 1 addition & 3 deletions paddle/fluid/operators/unity_build_rule.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,6 @@ register_unity_group(cc
scatter_nd_add_op.cc
scatter_op.cc
seed_op.cc
segment_pool_op.cc
select_input_op.cc
select_output_op.cc)
register_unity_group(cc
Expand Down Expand Up @@ -496,8 +495,7 @@ register_unity_group(cu
scale_op.cu
scatter_nd_add_op.cu
scatter_op.cu
seed_op.cu
segment_pool_op.cu)
seed_op.cu)
register_unity_group(cu
roi_pool_op.cu
selu_op.cu
Expand Down
19 changes: 19 additions & 0 deletions paddle/phi/infermeta/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,25 @@ void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
out->share_meta(x);
}

void SegmentPoolInferMeta(const MetaTensor& x,
const MetaTensor& segment_ids,
const std::string& pooltype,
MetaTensor* out,
MetaTensor* summed_ids,
MetaConfig config) {
auto dims = x.dims();
dims[0] = -1;
out->set_dims(dims);
out->set_dtype(x.dtype());
out->set_layout(x.layout());

if (pooltype == "MEAN") {
summed_ids->set_dims({-1, 1});
summed_ids->set_dtype(x.dtype());
summed_ids->set_layout(x.layout());
}
}

void BCELossInferMeta(const MetaTensor& input,
const MetaTensor& label,
MetaTensor* out,
Expand Down
8 changes: 8 additions & 0 deletions paddle/phi/infermeta/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ void CrossInferMeta(const MetaTensor& x,
MetaTensor* out);

void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);

void SegmentPoolInferMeta(const MetaTensor& x,
const MetaTensor& segment_ids,
const std::string& pooltype,
MetaTensor* out,
MetaTensor* summed_ids,
MetaConfig config = MetaConfig());

void BCELossInferMeta(const MetaTensor& input,
const MetaTensor& label,
MetaTensor* out,
Expand Down
4 changes: 3 additions & 1 deletion paddle/phi/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel)
# Some kernels depend on some targets that are not commonly used.
# These targets are not suitable for common dependencies.
# In this case, you need to manually generate them here.
set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel maxout_kernel maxout_grad_kernel put_along_axis_kernel put_along_axis_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel eigh_kernel)
set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel maxout_kernel maxout_grad_kernel put_along_axis_kernel put_along_axis_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel eigh_kernel segment_pool_kernel segment_pool_grad_kernel)
kernel_library(math_kernel DEPS ${COMMON_KERNEL_DEPS} cast_kernel copy_kernel)
kernel_library(softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
Expand All @@ -39,6 +39,8 @@ kernel_library(put_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scat
kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
kernel_library(eigh_kernel DEPS ${COMMON_KERNEL_DEPS} lapack_function)
kernel_library(segment_pool_kernel DEPS ${COMMON_KERNEL_DEPS} segment_pooling)
kernel_library(segment_pool_grad_kernel DEPS ${COMMON_KERNEL_DEPS} segment_pooling)

# 4. auto parse and build kernel targets by cmake
register_kernels(EXCLUDES ${COMMON_BAISC_KERNELS} ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS} ${COMMON_BAISC_KERNELS} )
Expand Down
26 changes: 26 additions & 0 deletions paddle/phi/kernels/cpu/segment_pool_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/segment_pool_grad_kernel.h"
#include "paddle/phi/kernels/impl/segment_pool_grad_kernel_impl.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

PD_REGISTER_KERNEL(segment_pool_grad,
CPU,
ALL_LAYOUT,
phi::SegmentPoolGradKernel,
float,
double) {}
Loading

0 comments on commit a07f19e

Please sign in to comment.