Skip to content

Commit

Permalink
[PHI] move diag_embed op to phi. (PaddlePaddle#44408)
Browse files Browse the repository at this point in the history
* move diag_embed to phi.
  • Loading branch information
ZHUI authored Jul 20, 2022
1 parent 889bdde commit 41f11d2
Show file tree
Hide file tree
Showing 12 changed files with 310 additions and 249 deletions.
93 changes: 10 additions & 83 deletions paddle/fluid/operators/diag_embed_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,89 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/diag_embed_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {

class DiagEmbedOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("Input"),
true,
platform::errors::NotFound("Input of DiagEmbedOp is not found."));

PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"),
true,
platform::errors::NotFound("Output of DiagEmbedOp is not found."));

int offset = ctx->Attrs().Get<int>("offset");
int dim1 = ctx->Attrs().Get<int>("dim1");
int dim2 = ctx->Attrs().Get<int>("dim2");

auto x_dims = ctx->GetInputDim("Input");

PADDLE_ENFORCE_GE(
dim1,
-(x_dims.size() + 1),
platform::errors::OutOfRange(
"Dim1 is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size() + 1),
x_dims.size(),
dim1));
PADDLE_ENFORCE_LE(
dim1,
x_dims.size(),
platform::errors::OutOfRange(
"Dim1 is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size() + 1),
x_dims.size(),
dim1));

PADDLE_ENFORCE_GE(
dim2,
-(x_dims.size() + 1),
platform::errors::OutOfRange(
"Dim2 is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size() + 1),
x_dims.size(),
dim2));
PADDLE_ENFORCE_LE(
dim2,
x_dims.size(),
platform::errors::OutOfRange(
"Dim2 is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size() + 1),
x_dims.size(),
dim2));

int dim1_ = dim1 < 0 ? x_dims.size() + dim1 + 1 : dim1;
int dim2_ = dim2 < 0 ? x_dims.size() + dim2 + 1 : dim2;
int offset_ = std::abs(offset);

PADDLE_ENFORCE_NE(dim1_,
dim2_,
platform::errors::InvalidArgument(
"diagonal dimensions should not be identical "
"%ld vs %ld.",
dim1,
dim2));

int new_dim_len = offset_ + x_dims[x_dims.size() - 1];
auto sizes = vectorize(x_dims);
sizes.pop_back();
sizes.insert(sizes.begin() + std::min(dim1_, dim2_), new_dim_len);
sizes.insert(sizes.begin() + std::max(dim1_, dim2_), new_dim_len);
ctx->SetOutputDim("Out", phi::make_ddim(sizes));
}
};

class DiagEmbedOpMaker : public framework::OpProtoAndCheckerMaker {
Expand Down Expand Up @@ -131,15 +59,14 @@ class DiagEmbedOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle

namespace ops = paddle::operators;
namespace platform = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(diag_embed,
DiagEmbedInferShapeFunctor,
PD_INFER_META(phi::DiagEmbedInferMeta));

REGISTER_OPERATOR(
diag_embed,
ops::DiagEmbedOp,
ops::DiagEmbedOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(diag_embed,
ops::DiagEmbedKernel<phi::CPUContext, int>,
ops::DiagEmbedKernel<phi::CPUContext, float>,
ops::DiagEmbedKernel<phi::CPUContext, double>,
ops::DiagEmbedKernel<phi::CPUContext, int64_t>);
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
DiagEmbedInferShapeFunctor);
30 changes: 0 additions & 30 deletions paddle/fluid/operators/diag_embed_op.cu

This file was deleted.

130 changes: 0 additions & 130 deletions paddle/fluid/operators/diag_embed_op.h

This file was deleted.

8 changes: 8 additions & 0 deletions paddle/phi/api/yaml/legacy_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,14 @@
func : determinant
backward : det_grad

- api : diag_embed
args : (Tensor x, int offset, int dim1, int dim2)
output : Tensor
infer_meta :
func : DiagEmbedInferMeta
kernel :
func : diag_embed

- api : divide
args : (Tensor x, Tensor y)
output : Tensor
Expand Down
63 changes: 63 additions & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,69 @@ void CumInferMeta(const MetaTensor& x,
out->share_lod(x);
}

void DiagEmbedInferMeta(
const MetaTensor& x, int offset, int dim1, int dim2, MetaTensor* out) {
auto x_dims = x.dims();

PADDLE_ENFORCE_GE(
dim1,
-(x_dims.size() + 1),
phi::errors::OutOfRange(
"Dim1 is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size() + 1),
x_dims.size(),
dim1));
PADDLE_ENFORCE_LE(
dim1,
x_dims.size(),
phi::errors::OutOfRange(
"Dim1 is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size() + 1),
x_dims.size(),
dim1));

PADDLE_ENFORCE_GE(
dim2,
-(x_dims.size() + 1),
phi::errors::OutOfRange(
"Dim2 is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size() + 1),
x_dims.size(),
dim2));
PADDLE_ENFORCE_LE(
dim2,
x_dims.size(),
phi::errors::OutOfRange(
"Dim2 is out of range (expected to be in range of [%ld, "
"%ld], but got %ld).",
-(x_dims.size() + 1),
x_dims.size(),
dim2));

int dim1_ = dim1 < 0 ? x_dims.size() + dim1 + 1 : dim1;
int dim2_ = dim2 < 0 ? x_dims.size() + dim2 + 1 : dim2;
int offset_ = std::abs(offset);

PADDLE_ENFORCE_NE(dim1_,
dim2_,
phi::errors::InvalidArgument(
"diagonal dimensions should not be identical "
"%ld vs %ld.",
dim1,
dim2));

int new_dim_len = offset_ + x_dims[x_dims.size() - 1];
auto sizes = vectorize(x_dims);
sizes.pop_back();
sizes.insert(sizes.begin() + std::min(dim1_, dim2_), new_dim_len);
sizes.insert(sizes.begin() + std::max(dim1_, dim2_), new_dim_len);
out->set_dims(phi::make_ddim(sizes));
out->set_dtype(x.dtype());
}

void DiagInferMeta(const MetaTensor& x,
int offset,
float padding_value,
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ void CumInferMeta(const MetaTensor& x,
bool reverse,
MetaTensor* out);

void DiagEmbedInferMeta(
const MetaTensor& x, int offset, int dim1, int dim2, MetaTensor* out);

void DiagInferMeta(const MetaTensor& x,
int offset,
float padding_value,
Expand Down
Loading

0 comments on commit 41f11d2

Please sign in to comment.