Skip to content

Commit

Permalink
[PatternLang] Don't rewrite expressions used outside of the pattern (a…
Browse files Browse the repository at this point in the history
…pache#5930)

* Don't rewrite expressions used outside of the pattern

* add comments
  • Loading branch information
Matthew Brookhart authored Jun 26, 2020
1 parent 96bf271 commit e1a1c2a
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 22 deletions.
62 changes: 40 additions & 22 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,8 @@ class PatternGrouper {
return gid_assignments_;
}
/* \brief Group expressions that match the pattern */
const std::vector<Group>& GroupMatches(const DFPattern& pattern, const Expr& pre) {
groups_ = {Group()};
const std::unordered_map<int, Group>& GroupMatches(const DFPattern& pattern, const Expr& pre) {
groups_.clear();
gid_assignments_.clear();

pattern_ = pattern;
Expand All @@ -487,15 +487,17 @@ class PatternGrouper {
for (size_t i = matcher_->expr_graph_.topological_order_.size(); i != 0; --i) {
size_t index = i - 1;
Expr current = matcher_->expr_graph_.topological_order_.at(index)->ref_;
if (auto op = current.as<FunctionNode>()) {
if (op->attrs.defined() && op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) {
pre_partitioned.insert(current);
PostOrderVisit(op->body,
[&pre_partitioned](const Expr& expr) { pre_partitioned.insert(expr); });
if (gid_assignments_.count(current) == 0) { // Don't visit nodes we've already grouped
if (auto op = current.as<FunctionNode>()) {
if (op->attrs.defined() && op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) {
pre_partitioned.insert(current);
PostOrderVisit(op->body,
[&pre_partitioned](const Expr& expr) { pre_partitioned.insert(expr); });
}
}
if (pre_partitioned.count(current) == 0 && matcher_->Match(pattern_, current)) {
CreateGroup(current);
}
}
if (pre_partitioned.count(current) == 0 && matcher_->Match(pattern_, current)) {
CreateGroup(current);
}
}
}
Expand Down Expand Up @@ -616,20 +618,37 @@ class PatternGrouper {
CHECK(DFPatternMatcher(body).Match(pattern_, body));
group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>());
group.name = extractor.GetName();
// Check to make sure we aren't overlapping with another group
// Check to make sure we aren't overlapping with another group or creating an invalid fusion
// The MatchExtractor will create a new graph by replacing nodes that match the inputs of the
// pattern with the input FunctionVar* Variables. The resulting memoization map will only
// contain nodes in the expression that matched the pattern. If a non-input node of the pattern
// (i.e., some piece of computation) overlaps with the nodes in a previous group, we'll have a
// situation where we try to rewrite the same node twice in the second rewriting or parition
// pass. This isn't valid, so we check for it here. We ignore Ops, functions, and constants
// because they exist more globally outside of the fusion.
for (auto kv : extractor.GetMemo()) {
if (gid_assignments_.count(kv.first) != 0 && inputs.count(kv.first) == 0 &&
kv.first.as<OpNode>() == nullptr && kv.first.as<FunctionNode>() == nullptr &&
kv.first.as<ConstantNode>() == nullptr) {
// Exit due to overlapping partitions
return;
// Similiarly, if interior nodes in a group are used outside of the group fusing to a single
// output would create an invalid graph tranformation, so we block the creation of such groups.
auto memo = extractor.GetMemo();
for (auto kv : memo) {
// Check to ensure that this node isn't an input or a global
if (inputs.count(kv.first) == 0 && kv.first.as<OpNode>() == nullptr &&
kv.first.as<FunctionNode>() == nullptr && kv.first.as<ConstantNode>() == nullptr) {
if (gid_assignments_.count(kv.first) != 0) {
// check to see if the node is use in other groups
// Exit due to overlapping partitions
return;
} else if (kv.second != body) {
// if the node isn't the ouput of the group
auto node = matcher_->expr_graph_.node_map_.at(kv.first);
for (auto* output : node->outputs_) {
// and the node is used by nodes outside of the group
if (memo.count(output->ref_) == 0) {
// Exit because nodes in this pattern's body are used outside the pattern
// fusing it would be invalid
return;
}
}
}
}
}
// Assign Group Ids
Expand All @@ -639,8 +658,7 @@ class PatternGrouper {
}

// Save Group
groups_.emplace_back(std::move(group));
CHECK_EQ(groups_[gid_].gid, gid_);
groups_[group.gid] = std::move(group);
}

/* \brief EmbedConst implements rules for embedding constants into partitioned functions or
Expand Down Expand Up @@ -675,7 +693,7 @@ class PatternGrouper {
}
// Internal State
DFPattern pattern_;
std::vector<Group> groups_;
std::unordered_map<int, Group> groups_;
std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> gid_assignments_;
DFPatternMatcher* matcher_ = nullptr;
IndexedGraph<DFPattern> pattern_graph_;
Expand Down Expand Up @@ -753,7 +771,7 @@ class PatternRewriter : protected MixedModeMutator {
}

DFPatternCallback callback_;
std::vector<PatternGrouper::Group> groups_;
std::unordered_map<int, PatternGrouper::Group> groups_;
std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> gid_assignments_;
};

Expand Down Expand Up @@ -805,7 +823,7 @@ class PatternPartitioner : protected MixedModeMutator {
}

Map<String, ObjectRef> attrs_;
std::vector<PatternGrouper::Group> groups_;
std::unordered_map<int, PatternGrouper::Group> groups_;
std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> gid_assignments_;
PackedFunc check_;
};
Expand Down
31 changes: 31 additions & 0 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,37 @@ def test_partition_double_batchnorm():
reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta)
assert tvm.ir.structural_equal(partitioned, reference)

def test_overlappting_partitions():
x = wildcard()
gamma = wildcard()
beta = wildcard()
moving_mean = wildcard()
moving_var = wildcard()
bn_node = is_op('nn.batch_norm')(x, gamma, beta, moving_mean, moving_var)
tuple_get_item_node = TupleGetItemPattern(bn_node, 0)

x = relay.var('x')
var = relay.var('var')
mean = relay.var('mean')
beta = relay.var('beta')
gamma = relay.var('gamma')
BN = relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon=1e-5)
T1 = BN[0]
T2 = BN[0]
add = T1 + T2

assert tuple_get_item_node.partition(add) == add

def test_partition_overused():
pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))

x = relay.var('input')
w = relay.var('weight')
conv2d = relay.op.nn.conv2d(x, w)
relu = relay.op.nn.relu(conv2d)
out = relu + conv2d

assert pattern.partition(out) == out

def test_partition_check():
pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))
Expand Down

0 comments on commit e1a1c2a

Please sign in to comment.