Extend AttrPattern to support CallNode and FunctionNode attributes (apache#5637)

* Extend AttrPattern to support CallNode and FunctionNode attributes

* Update tutorial and add breaks

* add func attr test
Matthew Brookhart authored May 21, 2020
1 parent 019da5d commit cafb498
Showing 4 changed files with 115 additions and 46 deletions.
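At a glance, the user-facing change is that `has_attr` now takes a dictionary of attribute constraints instead of a single name/value pair. A minimal before/after sketch, using only names that appear in this diff:

    # Before this commit: one attribute per call
    op = is_op('nn.dense').has_attr("TOpPattern", K_ELEMWISE)

    # After this commit: a dict, so several attributes can be constrained at once
    op = is_op('nn.dense').has_attr({"TOpPattern": K_ELEMWISE})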
docs/langref/relay_pattern.rst (2 additions, 2 deletions)

@@ -41,7 +41,7 @@ There are quite a few properties that are worth matching of operators below we e
 The next example is a dense operation with any operator that is marked element-wise::
 
     def test_no_match_attr():
-        op = is_op('nn.dense').has_attr("TOpPattern", K_ELEMWISE)
+        op = is_op('nn.dense').has_attr({"TOpPattern": K_ELEMWISE})
         op_pat = op(wildcard(), wildcard())
         x = relay.var('x')
         y = relay.var('y')
@@ -97,7 +97,7 @@ The high level design is to introduce a language of patterns for now we propose
 | *
 | pattern(pattern1, ... patternN)
 | has_type(pattern, type)
-| has_attr(pattern, attr, attr_value)
+| has_attr(pattern, attrs)
 | is_input(name)
 | pattern1 `|` pattern2
 | dominates(parent_pattern, path_pattern, child_pattern)
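To make the new grammar concrete, here is a small usage sketch. The imports and the layout values are assumptions based on TVM's conv2d defaults, not part of this diff:

    from tvm import relay
    from tvm.relay.dataflow_pattern import is_op, wildcard

    # Constrain two call attributes at once -- the dict form's main benefit.
    pat = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr(
        {"data_layout": "NCHW", "kernel_layout": "OIHW"})

    x = relay.var('x', shape=(1, 3, 224, 224))
    w = relay.var('w', shape=(8, 3, 3, 3))
    assert pat.match(relay.op.nn.conv2d(x, w))  # conv2d defaults to NCHW/OIHW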
python/tvm/relay/dataflow_pattern/__init__.py (9 additions, 12 deletions)

@@ -61,23 +61,20 @@ def __mul__(self, other):
     def __truediv__(self, other):
         return is_op("divide")(self, other)
 
-    def has_attr(self, attr_name: str, attr_value):
+    def has_attr(self, attrs):
         """
         Add an attribute constraint to this pattern
 
         Parameters
         ----------
-        attr_name: str
-            The name of the attribute to match
-        attr_value: Any
-            The value of the attribute to match
+        attrs: Dict[str, Object]
 
         Returns
         -------
         result: tvm.relay.dataflow_pattern.DFPattern
             The resulting AttrPattern
         """
-        attrs = make_node("DictAttrs", **{attr_name: attr_value})
+        attrs = make_node("DictAttrs", **attrs)
         return AttrPattern(self, attrs)
 
     def has_type(self, ttype):
@@ -237,26 +234,26 @@ def has_type(ttype, pattern: DFPattern = None) -> DFPattern:
     return TypePattern(pattern, ttype)
 
 
-def has_attr(attr_name: DFPattern, attr_value, pattern=None) -> DFPattern:
+def has_attr(attrs, pattern=None) -> DFPattern:
     """
     Syntactic sugar for creating an AttrPattern
 
     Parameters
     ----------
-    pattern: tvm.relay.dataflow_pattern.DFPattern
-        The input pattern.
-    attrs: tvm.Attrs
+    attrs: Dict[str, Object]
         The attributes to match
+    pattern: Optional[tvm.relay.dataflow_pattern.DFPattern]
+        The input pattern.
 
     Returns
     -------
     result: tvm.relay.dataflow_pattern.DFPattern
         The resulting AttrPattern
     """
     if pattern is None:
         pattern = wildcard()
-    return pattern.has_attr(attr_name, attr_value)
+    return pattern.has_attr(attrs)
 
 
 def dominates(parent: DFPattern, path: DFPattern, child: DFPattern) -> DFPattern:
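Because `pattern` now defaults to `wildcard()`, the free function can express a bare attribute constraint. A short sketch (imports assumed, mirroring the tests below):

    from tvm import relay
    from tvm.relay.dataflow_pattern import has_attr, wildcard

    # These two patterns are equivalent after this change.
    pat_a = has_attr({"Composite": "add"})             # wildcard() implied
    pat_b = wildcard().has_attr({"Composite": "add"})  # explicit receiver

    x = relay.var('x')
    y = relay.var('y')
    f = relay.Function([x, y], x + y).with_attr("Composite", "add")
    assert pat_a.match(f) and pat_b.match(f)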
src/relay/ir/dataflow_matcher.cc (55 additions, 21 deletions)

@@ -101,39 +101,73 @@ bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr
   return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
 }
 
+bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) {
+  switch (rhs.type_code()) {
+    case kDLInt:
+      if (auto* val = lhs.as<IntImmNode>()) {
+        return val->value == rhs.operator int64_t();
+      }
+      break;
+    case kDLFloat:
+      if (auto* val = lhs.as<FloatImmNode>()) {
+        return val->value == rhs.operator double();
+      }
+      break;
+    case kTVMStr:
+      std::cout << lhs << std::endl;
+      if (auto* val = lhs.as<tir::StringImmNode>()) {
+        return val->value == rhs.operator std::string();
+      } else if (auto* val = lhs.as<StringObj>()) {
+        return val->data == rhs.operator std::string();
+      }
+      break;
+    default:
+      CHECK(false) << "Unsupported type code in Pattern Node " << rhs.type_code();
+  }
+  return false;
+}
 
 bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
   bool matches = false;
+  auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
   if (const auto* op_node = expr.as<OpNode>()) {
     Op op = GetRef<Op>(op_node);
-    auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
     for (auto kv : attributes) {
       auto attr_name = kv.first;
+      auto attr_value = kv.second;
       auto op_map = Op::GetAttrMap<TVMRetValue>(attr_name);
       if (op_map.count(op)) {
-        switch (op_map[op].type_code()) {
-          case kDLInt:
-            if (auto* val = kv.second.as<IntImmNode>()) {
-              matches = val->value == op_map[op].operator int64_t();
-            }
-            break;
-          case kDLFloat:
-            if (auto* val = kv.second.as<FloatImmNode>()) {
-              matches = val->value == op_map[op].operator double();
-            }
-            break;
-          case kTVMStr:
-            if (auto* val = kv.second.as<tir::StringImmNode>()) {
-              matches = val->value == op_map[op].operator std::string();
-            }
-            break;
-          default:
-            CHECK(false) << "Unsupported type in Type Pattern Node";
-        }
+        matches = MatchRetValue(attr_value, op_map[op]);
       }
     }
+  } else if (auto* op = expr.as<CallNode>()) {
+    matches = true;
+    // TODO(mbrookhart): When OpNode Attrs move from TVMRetValue to the Object system, remove this
+    // and replace the whole thing with a Visitor-based approach
+    ReflectionVTable* reflection = ReflectionVTable::Global();
+    auto attrs_node = const_cast<Object*>(op->attrs.get());
+    auto attr_names = reflection->ListAttrNames(attrs_node);
+    for (auto kv : attributes) {
+      if (matches &&
+          std::find(attr_names.begin(), attr_names.end(), kv.first) != attr_names.end()) {
+        matches &= MatchRetValue(kv.second, reflection->GetAttr(attrs_node, kv.first));
+      } else {
+        matches = false;
+        break;
+      }
+    }
+  } else if (auto* op = expr.as<FunctionNode>()) {
+    matches = true;
+    for (auto kv : attributes) {
+      if (matches && op->attrs->dict.count(kv.first)) {
+        matches &= StructuralEqual()(kv.second, op->attrs->dict[kv.first]);
+      } else {
+        matches = false;
+        break;
+      }
+    }
   }
-  return matches;
+  return matches && VisitDFPattern(attr_pattern->pattern, expr);
 }
 
 Array<DFPattern> reverse(const Array<DFPattern>& args) {
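The matcher now distinguishes three cases: operator-level attributes registered on an OpNode, call-level attributes carried by a CallNode, and function-level attributes attached with `with_attr`. A sketch exercising all three paths; `K_BROADCAST = 1` mirrors the constant defined in the test file below:

    from tvm import relay
    from tvm.relay.dataflow_pattern import is_op, wildcard

    K_BROADCAST = 1  # OpPatternKind code for broadcast ops

    x, y = relay.var('x'), relay.var('y')

    # OpNode path: 'add' is registered with TOpPattern == kBroadcast.
    assert is_op('add').has_attr({"TOpPattern": K_BROADCAST})(wildcard(), wildcard()).match(x + y)

    # CallNode path: attributes on the call itself, e.g. conv2d's data_layout.
    conv = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"data_layout": "NCHW"})
    assert conv.match(relay.op.nn.conv2d(x, y))

    # FunctionNode path: attributes attached with with_attr.
    f = relay.Function([x, y], x + y).with_attr("Composite", "add")
    assert wildcard().has_attr({"Composite": "add"}).match(f)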
tests/python/relay/test_dataflow_pattern.py (49 additions, 11 deletions)

@@ -77,7 +77,7 @@ def test_TypePattern():
     assert ty_pat.type == ttype
 
 def test_AttrPattern():
-    op = is_op('add').has_attr("TOpPattern", K_ELEMWISE)
+    op = is_op('add').has_attr({"TOpPattern": K_ELEMWISE})
     assert isinstance(op, AttrPattern)
     assert op.attrs["TOpPattern"] == K_ELEMWISE
 
@@ -225,19 +225,57 @@ def test_no_match_type():
     ty_pat = has_type(relay.TensorType((10, 10), "float32"))
     assert not ty_pat.match(x)
 
-def test_match_attr():
-    op = is_op('add').has_attr("TOpPattern", K_BROADCAST)
+def test_match_op_attr():
+    op = is_op('add').has_attr({"TOpPattern": K_BROADCAST})
     op_pat = op(wildcard(), wildcard())
     x = relay.var('x')
     y = relay.var('y')
     assert op_pat.match(x + y)
 
-def test_no_match_attr():
-    op = is_op('nn.dense').has_attr("TOpPattern", K_ELEMWISE)
+def test_no_match_op_attr():
+    op = is_op('nn.dense').has_attr({"TOpPattern": K_ELEMWISE})
     op_pat = op(wildcard(), wildcard())
     x = relay.var('x')
     y = relay.var('y')
     assert not op_pat.match(relay.op.nn.dense(x, y))
+    op = is_op('add').has_attr({"TOpPattern": K_BROADCAST})
+    op_pat = op(wildcard(), wildcard())
+    x = relay.var('x')
+    y = relay.var('y')
+    assert not op_pat.match(x - y)
+
+def test_match_func_attr():
+    pattern = wildcard().has_attr({"Composite": "add"})
+    x = relay.var('x')
+    y = relay.var('y')
+    f = relay.Function([x, y], x + y).with_attr("Composite", "add")
+    assert pattern.match(f)
+
+def test_no_match_func_attr():
+    pattern = wildcard().has_attr({"Composite": "add"})
+    x = relay.var('x')
+    y = relay.var('y')
+
+    f = relay.Function([x, y], x + y).with_attr("RandomTest", "add")
+    assert not pattern.match(f)
+    f = relay.Function([x, y], x + y).with_attr("Composite", "conv_bias")
+    assert not pattern.match(f)
+
+def test_match_call_attr():
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"data_layout": "NCHW"})
+    x = relay.var('x')
+    y = relay.var('y')
+    assert is_conv2d.match(relay.op.nn.conv2d(x, y))
+
+def test_no_match_call_attr():
+    x = relay.var('x')
+    y = relay.var('y')
+
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"data_layout": "NHWC"})
+    assert not is_conv2d.match(relay.op.nn.conv2d(x, y))
+
+    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"RandomAttr": "NCHW"})
+    assert not is_conv2d.match(relay.op.nn.conv2d(x, y))
 
 def test_match_diamond():
     # Pattern
@@ -301,7 +339,7 @@ def test_match_fake_diamond():
 def test_match_dominator():
     # Pattern
     is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard())
+    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
     reduction = is_op('add')(wildcard(), wildcard())
     diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
 
@@ -344,7 +382,7 @@ def test_match_dominator():
 
     # Fuzzy path/nested Diamond
     is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) | is_op('add')(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard())
     reduction = is_op('add')(wildcard(), wildcard())
     diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
@@ -361,7 +399,7 @@
 
 def test_not_match_dominator():
     is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard())
+    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
     reduction = is_op('add')(wildcard(), wildcard())
     diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
 
@@ -578,7 +616,7 @@ def __init__(self):
         self.weight = wildcard()
 
         is_conv2d = is_op('nn.conv2d')(self.inp, self.weight)
-        is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) | is_op('add')(wildcard(), wildcard())
+        is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard())
         reduction = is_op('add')(wildcard(), wildcard())
         self.pattern = dominates(is_conv2d, is_unary_elemwise, reduction)
 
@@ -740,7 +778,7 @@ def test_double_partition():
 def test_partition_dominator():
     # Pattern
     is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard())
+    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
     reduction = is_op('add')(wildcard(), wildcard())
     diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
 
@@ -765,7 +803,7 @@ def generate_diamond(inp, weight):
 def test_quadruple_partition_dominator():
     # Pattern
     is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
-    is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) | is_op('add')(wildcard(), wildcard())
+    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard())
     reduction = is_op('add')(wildcard(), wildcard())
     diamond = dominates(is_conv2d, is_unary_elemwise, reduction)
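Several of the tests above feed attr-constrained patterns into partitioning. A compact sketch of that flow, with an illustrative graph borrowed from the dominator tests (variable names and the final assertion are assumptions, not part of this diff):

    from tvm import relay
    from tvm.relay.dataflow_pattern import dominates, is_op, wildcard

    K_ELEMWISE = 0  # as defined in the test file

    is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
    is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
    reduction = is_op('add')(wildcard(), wildcard())
    diamond = dominates(is_conv2d, is_unary_elemwise, reduction)

    inp = relay.var('input')
    weight = relay.var('weight')
    conv2d = relay.op.nn.conv2d(inp, weight)
    relu = relay.op.nn.relu(conv2d)
    leaky = relay.op.nn.leaky_relu(conv2d, alpha=0)
    out = relu + leaky

    # The matched diamond is wrapped in a Function tagged PartitionedFromPattern.
    partitioned = diamond.partition(out)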
