Skip to content

Commit

Permalink
[luci] Add test for ConvertNCHWToNHWC pre/post-reshape (Samsung#7379)
Browse files Browse the repository at this point in the history
This adds a test for ConvertNCHWToNHWC pre/post-reshape.

ONE-DCO-1.0-Signed-off-by: Hyukjin Jeong <[email protected]>
  • Loading branch information
jinevening authored Aug 4, 2021
1 parent b581036 commit 7f6110d
Showing 1 changed file with 71 additions and 0 deletions.
71 changes: 71 additions & 0 deletions compiler/luci/pass/src/ConvertNCHWToNHWCPass.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,61 @@ class AddGraph final : public SimpleGraph
luci::CircleConst *beta = nullptr;
};

class NHWCReluGraph final : public SimpleGraph
{
protected:
loco::Node *insertGraphBody(loco::Node *input) override
{
relu = g.nodes()->create<luci::CircleRelu>();
pre_reshape = g.nodes()->create<luci::CircleReshape>();
post_reshape = g.nodes()->create<luci::CircleReshape>();
pre_shape = g.nodes()->create<luci::CircleConst>();
post_shape = g.nodes()->create<luci::CircleConst>();

pre_shape->dtype(loco::DataType::S32);
post_shape->dtype(loco::DataType::S32);

uint32_t channel_size = 16;
auto in = loco::must_cast<luci::CircleNode *>(input);
in->shape({1, channel_size, 4, 4});
pre_shape->shape({4});
post_shape->shape({4});

pre_shape->size<loco::DataType::S32>(4);
pre_shape->at<loco::DataType::S32>(0) = 1;
pre_shape->at<loco::DataType::S32>(1) = 4;
pre_shape->at<loco::DataType::S32>(2) = 4;
pre_shape->at<loco::DataType::S32>(3) = channel_size;

post_shape->size<loco::DataType::S32>(4);
post_shape->at<loco::DataType::S32>(0) = 1;
post_shape->at<loco::DataType::S32>(1) = channel_size;
post_shape->at<loco::DataType::S32>(2) = 4;
post_shape->at<loco::DataType::S32>(3) = 4;

pre_reshape->tensor(input);
pre_reshape->shape(pre_shape);

relu->features(pre_reshape);

post_reshape->tensor(relu);
post_reshape->shape(post_shape);

relu->name("Relu");
pre_reshape->name("pre-reshape");
post_reshape->name("post-reshape");

return post_reshape;
}

public:
luci::CircleRelu *relu = nullptr;
luci::CircleReshape *pre_reshape = nullptr;
luci::CircleReshape *post_reshape = nullptr;
luci::CircleConst *pre_shape = nullptr;
luci::CircleConst *post_shape = nullptr;
};

class AddScalarGraph final : public SimpleGraph
{
protected:
Expand Down Expand Up @@ -618,6 +673,22 @@ TEST(ConvertNCHWToNHWC, Add)
check_pre_trans(g.output->from());
}

TEST(ConvertNCHWToNHWC, NHWC_Relu)
{
// Relu is already NHWC, so it should not be converted
// i.e., the graph is not changed
NHWCReluGraph g;
g.init();

run_phase(&g.g, false, false);

EXPECT_EQ(g.pre_reshape, g.relu->features());

auto relu_succs = loco::succs(g.relu);
EXPECT_EQ(1, relu_succs.size());
EXPECT_EQ(g.post_reshape, *relu_succs.begin());
}

TEST(ConvertNCHWToNHWC, AddScalar)
{
AddScalarGraph g;
Expand Down

0 comments on commit 7f6110d

Please sign in to comment.