Skip to content

Commit

Permalink
[luci] Enable FuseInstNorm Ver 5 (Samsung#7331)
Browse files Browse the repository at this point in the history
This will enable fusion of InstanceNorm pattern version 5.

ONE-DCO-1.0-Signed-off-by: SaeHie Park <[email protected]>
  • Loading branch information
seanshpark authored Jul 27, 2021
1 parent faf7c35 commit f027bc9
Showing 1 changed file with 135 additions and 2 deletions.
137 changes: 135 additions & 2 deletions compiler/luci/pass/src/FuseInstanceNormPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,40 @@ namespace
* |
* V
* [Out]
*-------------------------------------------------------------------
* Version_5
* [In]
* |
* V
* +----------- ifm -----+ (reduction indices)
* | | | |
* | | V V
* | | mean_of_ifm ----------------+
* | V | |
* | sqdiff <--+ (reduction indices) |
* | | | |
* | V | |
* | mean_as_variance <---+ const_as_epsilon |
* | | | |
* | V | |
* | add_as_variance <--------+ |
* | | |
* | V |
* | rsqrt |
* | | |
* | +--+--+ |
* | | | |
* V V V |
* mul_as_scaled_ifm mul_as_scaled_mean <-------------+
* | |
* | const_as_beta |
* | | V
* | +------> sub
* V |
* add_as_terminal <----------+
* |
* V
* [Out]
*/
class InstanceNormPattern final
{
Expand All @@ -273,6 +307,7 @@ class InstanceNormPattern final
Version_2,
Version_3,
Version_4,
Version_5,
};

InstanceNormPattern(luci::CircleAdd *candidate, PatternVersion pv)
Expand Down Expand Up @@ -566,6 +601,62 @@ template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion:
return true;
}

/**
 * @brief Try to match the Version_5 instance-norm sub-graph ending at add_as_terminal
 *
 * The walk goes backwards from add_as_terminal and binds the pattern members
 * one by one, following the Version_5 diagram:
 *
 *   add_as_terminal    = mul_as_scaled_ifm + sub
 *   mul_as_scaled_ifm  = ifm * rsqrt
 *   rsqrt              = Rsqrt(add_as_variance)
 *   add_as_variance    = mean_as_variance + const_as_epsilon
 *   mean_as_variance   = Mean(sqdiff), sqdiff = SquaredDifference(ifm, mean_of_ifm)
 *   sub                = const_as_beta - mul_as_scaled_mean
 *   mul_as_scaled_mean = rsqrt * mean_of_ifm
 *
 * @return true (and _matched set) when every check holds
 *
 * @note CHECK_OR_FALSE is defined elsewhere; it is expected to leave this
 *       function with false as soon as its condition does not hold. Members
 *       bound before a failing check keep the values assigned so far.
 * @note The order of checks matters: later checks dereference members bound
 *       by earlier ones.
 */
template <> bool InstanceNormPattern::match<InstanceNormPattern::PatternVersion::Version_5>()
{
  // Terminal add takes the scaled input and the (beta - scaled mean) branch
  CHECK_OR_FALSE(luci::fill(&mul_as_scaled_ifm, &sub).with_commutative_args_of(add_as_terminal));
  CHECK_OR_FALSE(luci::fill(&ifm, &rsqrt).with_commutative_args_of(mul_as_scaled_ifm));

  // Input feature map must have a valid rank-4 shape with a known last dimension
  luci::CircleNode *const ifm_node = loco::must_cast<luci::CircleNode *>(ifm);
  CHECK_OR_FALSE(ifm_node->shape_status() == luci::ShapeStatus::VALID);
  CHECK_OR_FALSE(ifm_node->rank() == 4);
  CHECK_OR_FALSE(ifm_node->dim(3).known());
  const uint32_t channel_depth = ifm_node->dim(3).value();

  // rsqrt(variance + epsilon)
  add_as_variance = dynamic_cast<luci::CircleAdd *>(rsqrt->x());
  CHECK_OR_FALSE(add_as_variance);

  CHECK_OR_FALSE(
    luci::fill(&mean_as_variance, &const_as_epsilon).with_commutative_args_of(add_as_variance));

  // Epsilon has to be a single FLOAT32 value
  CHECK_OR_FALSE(const_as_epsilon->dtype() == loco::DataType::FLOAT32);
  // TODO Support regarding broadcast
  CHECK_OR_FALSE(const_as_epsilon->size<loco::DataType::FLOAT32>() == 1);

  CHECK_OR_FALSE(is_instance_mean_v1(mean_as_variance));

  // Variance is the mean of the squared difference between ifm and its mean
  sqdiff = dynamic_cast<luci::CircleSquaredDifference *>(mean_as_variance->input());
  CHECK_OR_FALSE(sqdiff);

  // The squared difference must consume the very same ifm found above
  loco::Node *sqdiff_ifm = nullptr;
  CHECK_OR_FALSE(luci::fill(&sqdiff_ifm, &mean_of_ifm).with_commutative_args_of(sqdiff));
  CHECK_OR_FALSE(sqdiff_ifm == ifm);
  CHECK_OR_FALSE(is_instance_mean_v1(mean_of_ifm));
  CHECK_OR_FALSE(mean_of_ifm->input() == ifm);

  // sub = const_as_beta - mul_as_scaled_mean (operand order is significant here)
  const_as_beta = dynamic_cast<luci::CircleConst *>(sub->x());
  CHECK_OR_FALSE(const_as_beta);
  CHECK_OR_FALSE(is_1D_with_dummy_dim(const_as_beta, channel_depth));

  mul_as_scaled_mean = dynamic_cast<luci::CircleMul *>(sub->y());
  CHECK_OR_FALSE(mul_as_scaled_mean);

  // The scaled mean must reuse the rsqrt and mean_of_ifm nodes already bound
  luci::CircleRsqrt *scaled_mean_rsqrt = nullptr;
  luci::CircleMean *scaled_mean_input = nullptr;
  CHECK_OR_FALSE(luci::fill(&scaled_mean_rsqrt, &scaled_mean_input)
                   .with_commutative_args_of(mul_as_scaled_mean));
  CHECK_OR_FALSE(scaled_mean_rsqrt == rsqrt);
  CHECK_OR_FALSE(scaled_mean_input == mean_of_ifm);

  // This pattern has no gamma multiplication, so gamma is assumed to be 1.0
  // and a constant is created for it.
  const_as_gamma = make_const_one(add_as_terminal->graph(), 1.0f);
  const_as_gamma->name(add_as_terminal->name() + "/gamma");

  _matched = true;
  return true;
}

bool InstanceNormPattern::matched()
{
if (_matched)
Expand All @@ -583,6 +674,8 @@ bool InstanceNormPattern::matched()
return match<PatternVersion::Version_3>();
case PatternVersion::Version_4:
return match<PatternVersion::Version_4>();
case PatternVersion::Version_5:
return match<PatternVersion::Version_5>();

default:
break;
Expand Down Expand Up @@ -767,6 +860,31 @@ template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Ve
replace(_p.div).with(instance_norm);
}

/**
 * @brief Replace a matched Version_5 pattern with a single instance-norm node
 *
 * Reshapes gamma/beta, creates the instance-norm node in the graph that owns
 * add_as_terminal, attaches the combined origins of the fused nodes to it and
 * finally redirects the users of add_as_terminal to the new node.
 */
template <> void FuseInstanceNorm::apply<InstanceNormPattern::PatternVersion::Version_5>()
{
  auto *owner_graph = _p.add_as_terminal->graph();

  reshape_gamma_beta();

  auto *fused_node = create_inst_norm(owner_graph);

  // Collect the origin of every node that is folded into the fused node
  std::vector<std::shared_ptr<luci::CircleNodeOrigin>> fused_origins = {
    luci::get_origin(_p.mean_of_ifm),
    luci::get_origin(_p.sqdiff),
    luci::get_origin(_p.mean_as_variance),
    luci::get_origin(_p.add_as_variance),
    luci::get_origin(_p.rsqrt),
    luci::get_origin(_p.mul_as_scaled_ifm),
    luci::get_origin(_p.mul_as_scaled_mean),
    luci::get_origin(_p.sub),
    luci::get_origin(_p.add_as_terminal),
  };

  luci::add_origin(fused_node, luci::composite_origin(fused_origins));

  // The terminal add is the output of the pattern; swap it for the fused node
  replace(_p.add_as_terminal).with(fused_node);
}

void FuseInstanceNorm::apply()
{
assert(_p.matched());
Expand All @@ -785,6 +903,9 @@ void FuseInstanceNorm::apply()
case InstanceNormPattern::PatternVersion::Version_4:
apply<InstanceNormPattern::PatternVersion::Version_4>();
break;
case InstanceNormPattern::PatternVersion::Version_5:
apply<InstanceNormPattern::PatternVersion::Version_5>();
break;

default:
break;
Expand Down Expand Up @@ -942,7 +1063,19 @@ bool fuse_instance_norm(luci::CircleAdd *add)
return true;
}

if (pv == InstanceNormPattern::PatternVersion::Version_2)
if (pv == InstanceNormPattern::PatternVersion::Version_1)
{
// if Version_1 failed, try with Version_5
pv = InstanceNormPattern::PatternVersion::Version_5;
InstanceNormPattern pattern(add, pv);
if (pattern.matched())
{
FuseInstanceNorm fuse(pattern);
fuse.apply();
return true;
}
}
else if (pv == InstanceNormPattern::PatternVersion::Version_2)
{
// if Version_2 failed, try with Version_3
pv = InstanceNormPattern::PatternVersion::Version_3;
Expand Down Expand Up @@ -989,7 +1122,7 @@ bool FuseInstanceNormPass::run(loco::Graph *g)
{
bool changed = false;

// Check Version_1, Version_2, Version_3
// Check Version_1, Version_2, Version_3, Version_5
for (auto node : loco::active_nodes(loco::output_nodes(g)))
{
auto add = dynamic_cast<luci::CircleAdd *>(node);
Expand Down

0 comments on commit f027bc9

Please sign in to comment.