Skip to content

Commit

Permalink
[RELAY] Add occurs check before unification (apache#2012)
Browse files Browse the repository at this point in the history
  • Loading branch information
wweic authored and tqchen committed Oct 27, 2018
1 parent a0c813b commit 4fbb7c8
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/relay/pass/type_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ Type TypeSolver::Unify(const Type& dst, const Type& src) {
// - handle shape pattern matching
TypeNode* lhs = GetTypeNode(dst);
TypeNode* rhs = GetTypeNode(src);

// do occur check so we don't create self-referencing structure
if (lhs->FindRoot() == rhs->FindRoot()) {
return lhs->resolved_type;
}
if (lhs->resolved_type.as<IncompleteTypeNode>()) {
MergeFromTo(lhs, rhs);
return rhs->resolved_type;
Expand Down
22 changes: 22 additions & 0 deletions tests/cpp/relay_pass_type_infer_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <tvm/relay/pass.h>

TEST(Relay, SelfReference) {
using namespace tvm;
auto type_a = relay::TypeVarNode::make("a", relay::TypeVarNode::kType);
auto type_b = relay::TypeVarNode::make("b", relay::TypeVarNode::kType);
auto x = relay::VarNode::make("x", type_a);
auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, type_b, Array<relay::TypeVar>{});
auto fx = relay::CallNode::make(f, Array<relay::Expr>{ x });
auto type_fx = relay::InferType(fx, relay::EnvironmentNode::make(Map<relay::GlobalVar, relay::Function>{}));
CHECK_EQ(type_fx->checked_type(), type_a);
}

int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
17 changes: 17 additions & 0 deletions tests/python/relay/test_type_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,22 @@ def test_type_args():
assert sh2[0].value == 1
assert sh2[1].value == 10

def test_self_reference():
"""
Program:
def f(x) {
return x;
}
"""
a = relay.TypeVar("a")
x = relay.var("x", a)
sb = relay.ScopeBuilder()
f = relay.Function([x], x)
fx = relay.Call(f, [x])
assert relay.ir_pass.infer_type(x).checked_type == a
assert relay.ir_pass.infer_type(f).checked_type == relay.FuncType([a], a)
assert relay.ir_pass.infer_type(fx).checked_type == a

if __name__ == "__main__":
test_free_expr()
test_dual_op()
Expand All @@ -117,3 +133,4 @@ def test_type_args():
test_tuple()
test_free_expr()
test_type_args()
test_self_reference()

0 comments on commit 4fbb7c8

Please sign in to comment.