From becf080e4ad43b0cc5a97b063729bb20bee308d3 Mon Sep 17 00:00:00 2001 From: Zachary DeVito Date: Tue, 1 Oct 2019 16:37:34 -0700 Subject: [PATCH] add dynamic isinstance (#26269) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26269 previously isinstance only worked when we could statically determine if it were true/false. Now we actually can issue an isinstance check in case where it is dependent on the runtime type, e.g. Optional[int] being an instance of int. This is not very useful on its own yet, but with type refinement and allowing Any as an argument type this will allow for python-style "overloaded" functions such that we can remove our __overload__ support. Test Plan: Imported from OSS Differential Revision: D17412853 Pulled By: zdevito fbshipit-source-id: e2c37040f25f6b94ee1676854fceecd22de190ef --- aten/src/ATen/core/interned_strings.h | 7 ++- test/test_jit.py | 26 ++++++-- torch/csrc/jit/ir.cpp | 19 ++++++ torch/csrc/jit/ir.h | 5 ++ torch/csrc/jit/passes/alias_analysis.cpp | 2 + torch/csrc/jit/passes/python_print.cpp | 54 ++++++++++++++++- torch/csrc/jit/register_prim_ops.cpp | 33 +++++++++++ torch/csrc/jit/script/compiler.cpp | 75 +++++++++++++++++++----- 8 files changed, 195 insertions(+), 26 deletions(-) diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 5b438ed5dc..e5837680ff 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -49,7 +49,7 @@ namespace c10 { _(prim, IgnoredPythonOp) \ _(prim, Reverse) \ _(prim, Return) \ - _(prim, ReturnStmt) \ + _(prim, ReturnStmt) \ _(prim, BreakStmt) \ _(prim, ContinueStmt) \ _(prim, Store) \ @@ -93,6 +93,7 @@ namespace c10 { _(prim, enumerate) \ _(prim, range) \ _(prim, rangelist) \ + _(prim, isinstance) \ _(aten, _grad_sum_to_size) \ _(aten, _size_if_not_equal) \ _(aten, _ncf_unsqueeze) \ @@ -221,7 +222,9 @@ namespace c10 { _(attr, beg) \ _(attr, idx) \ _(attr, split) \ - _(attr, slot) + _(attr, slot) \ + _(attr, kinds) \ + _(attr, types) #else #define FORALL_NS_SYMBOLS(_) \ _(namespaces, prim) \ diff --git a/test/test_jit.py b/test/test_jit.py index aad321fa14..c0436be81a 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -6778,11 +6778,12 @@ def test(inp, typ, type_hint): test(inp, typ, type_hint) # test optional isinstance check - with self.assertRaisesRegex(RuntimeError, "Optional isinstance check is not supported"): - @torch.jit.script - def opt_func(x): - # type: (Optional[int]) -> bool - return isinstance(x, int) + @torch.jit.script + def opt_func(x): + # type: (Optional[int]) -> bool + return isinstance(x, int) + self.assertTrue(opt_func(3)) + self.assertFalse(opt_func(None)) def test_dropout_eval(self): class ScriptedConv2d(torch.jit.ScriptModule): @@ -14027,6 +14028,21 @@ def test_non_primitive_types(x): out = test_non_primitive_types(_MyNamedTuple(value=torch.tensor(5.0))) self.assertEqual(out, torch.tensor(6.0)) + def test_isinstance_dynamic(self): + @torch.jit.script + def foo(a): + # type: (Optional[List[int]]) -> int + b = 0 + if isinstance(a, (int, (float,), list, str)): + b += 1 + if isinstance(a, (int, str)): + b += 1 + if isinstance(a, List[int]): + b += 1 + return b + self.assertEqual(foo([3, 4]), 2) + self.assertEqual(foo(None), 0) + def test_function_overloads(self): # TODO: pyflakes currently does not compose @overload annotation with other # decorators. This is fixed on master but not on version 2.1.1. diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp index a5f7b9381e..4f1b121bbb 100644 --- a/torch/csrc/jit/ir.cpp +++ b/torch/csrc/jit/ir.cpp @@ -1517,6 +1517,25 @@ Node* Graph::createLoad(const std::string& name, const TypePtr& type) { return n; } +Node* Graph::createIsInstance( + Value* v, + at::ArrayRef types, + bool is_list, + bool is_tuple) { + auto n = create(prim::isinstance, {v}, /*num_outputs*/ 1); + std::vector kinds; + if (is_list) { + kinds.push_back("list"); + } + if (is_tuple) { + kinds.push_back("tuple"); + } + n->ss_(attr::kinds, std::move(kinds)); + n->tys_(attr::types, types.vec()); + n->output()->setType(BoolType::get()); + return n; +} + Value* Graph::insertFunctionCall( Function* callee, const script::MatchedSchema& matched) { diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h index 925e3e06d7..695207a6eb 100644 --- a/torch/csrc/jit/ir.h +++ b/torch/csrc/jit/ir.h @@ -1094,6 +1094,11 @@ struct Graph { } TORCH_API Node* createStore(const std::string& name, Value* v); TORCH_API Node* createLoad(const std::string& name, const TypePtr& type); + TORCH_API Node* createIsInstance( + Value* v, + at::ArrayRef types, + bool is_list, + bool is_tuple); TORCH_API Value* insertFunctionCall( Function* callee, diff --git a/torch/csrc/jit/passes/alias_analysis.cpp b/torch/csrc/jit/passes/alias_analysis.cpp index 27c738357e..991f97478f 100644 --- a/torch/csrc/jit/passes/alias_analysis.cpp +++ b/torch/csrc/jit/passes/alias_analysis.cpp @@ -359,6 +359,7 @@ void AliasDb::analyzeImpl(Node* node) { return analyzeConservative(node); case prim::Print: case prim::Uninitialized: + case prim::isinstance: // These ops do nothing return; default: @@ -1246,6 +1247,7 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) { prim::CallFunction, prim::CallMethod, aten::wait, + prim::isinstance, }; // Operators that should not be used by alias analysis diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp index dd5f323714..b70a36e1e9 100644 --- a/torch/csrc/jit/passes/python_print.cpp +++ b/torch/csrc/jit/passes/python_print.cpp @@ -764,9 +764,7 @@ struct PythonPrintPass { registerClassDependencies(containedType); } } - - void printNode(Node* node, bool print_const) { - WithSourceRange guard(&source_range_stack_, node); + void scanTypeDependencies(Node* node) { // Check for class dependencies. If this node inputs or outputs a class // type, we need to add it to our table of dependencies. for (const auto input : node->inputs()) { @@ -775,7 +773,26 @@ struct PythonPrintPass { for (const auto output : node->outputs()) { registerClassDependencies(output->type()); } + for (const auto& name : node->attributeNames()) { + switch (node->kindOf(name)) { + case AttributeKind::ty: + registerClassDependencies(node->ty(name)); + break; + case AttributeKind::tys: + for (const TypePtr& t : node->tys(name)) { + registerClassDependencies(t); + } + break; + default: + // noop + break; + } + } + } + void printNode(Node* node, bool print_const) { + WithSourceRange guard(&source_range_stack_, node); + scanTypeDependencies(node); if (!print_const && node->kind() == prim::Constant) return; splitLongInlines(node->inputs()); @@ -1130,6 +1147,36 @@ struct PythonPrintPass { } stmt << ")"; } break; + case prim::isinstance: { + stmt << "isinstance(" << useOf(node->input()) << ", "; + const auto& types = node->tys(attr::types); + const auto& kinds = node->ss(attr::kinds); + if (types.size() == 1 && kinds.size() == 0) { + stmt << types.at(0)->python_str(); + } else if (kinds.size() == 1 && types.size() == 0) { + stmt << kinds.at(0); + } else { + // check multiple things, e.g. (str, list, int) + stmt << "("; + bool first = true; + for (const TypePtr& typ : types) { + if (!first) { + stmt << ", "; + } + stmt << typ->python_str(); + first = false; + } + for (const std::string& kind : kinds) { + if (!first) { + stmt << ", "; + } + stmt << kind; + first = false; + } + stmt << ")"; + } + stmt << ")"; + } break; default: { printOpName(stmt, node->kind()); const FunctionSchema& schema = node->schema(); @@ -1497,6 +1544,7 @@ bool printerHasSpecialCaseFor(Symbol sym) { prim::GetAttr, prim::SetAttr, prim::CallFunction, + prim::isinstance, }; // WARNING: by adding a value to this set, you are asserting that your diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 26b77e152a..13792e7bec 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -1314,6 +1314,39 @@ RegisterOperators reg( return 0; }; }, + aliasAnalysisSpecialCase()), + Operator( + prim::isinstance, + [](const Node* node) -> Operation { + std::vector types = node->tys(attr::types); + bool is_list = false; + bool is_tuple = false; + for (const std::string& kind : node->ss(attr::kinds)) { + if (kind == "list") { + is_list = true; + } else if (kind == "tuple") { + is_tuple = true; + } else { + TORCH_INTERNAL_ASSERT(false, "unrecognized type kind ", kind); + } + } + return [types, is_list, is_tuple](Stack& stack) { + TypePtr ty = pop(stack).type(); + if ((is_list && ty->kind() == ListType::Kind) || + (is_tuple && ty->kind() == TupleType::Kind)) { + push(stack, true); + return 0; + } + for (const TypePtr& to_check : types) { + if (ty->isSubtypeOf(to_check)) { + push(stack, true); + return 0; + } + } + push(stack, false); + return 0; + }; + }, aliasAnalysisSpecialCase())}); RegisterOperators logging_operators( diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index d6835af45f..245e20fd37 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -140,9 +140,13 @@ struct CondValue { : value_(value), refinements_(std::move(refinements)), static_if_(static_if) {} - CondValue(Graph& g, const SourceRange& loc, bool static_value) + CondValue( + Graph& g, + const SourceRange& loc, + bool static_value, + RefinementSet refinements) : value_(g.insertConstant(static_value, loc)), - refinements_({}), + refinements_(std::move(refinements)), static_if_(static_value) {} Value* value() const { return value_; @@ -1023,12 +1027,12 @@ struct to_ir { // MA, MM, MN, NM, NN, AM -> cannot prove anything statically bool its_is = expr.kind() == TK_IS; if (lhs_none == ALWAYS && rhs_none == ALWAYS) { - return CondValue(*graph, expr.range(), its_is); + return CondValue(*graph, expr.range(), its_is, {}); } else if ( (lhs_none == ALWAYS && rhs_none == NEVER) || (lhs_none == NEVER && rhs_none == ALWAYS)) { // lhs_val/rhs_val with A/M: only emit never_none_branch - return CondValue(*graph, expr.range(), !its_is); + return CondValue(*graph, expr.range(), !its_is, {}); } else { auto kind = getNodeKind(expr.kind(), expr.get()->trees().size()); Value* cond_value = emitBuiltinCall( @@ -1470,6 +1474,40 @@ struct to_ir { TypePtr type = typeParser_.parseTypeFromExpr(classinfo); types.emplace_back(type); } + bool staticallyTrue(const TypePtr& actual_type) { + // is this isinstance check statically true? + if ((list_check && actual_type->kind() == ListType::Kind) || + (tuple_check && actual_type->kind() == TupleType::Kind)) { + return true; + } + for (const TypePtr& typ : types) { + if (actual_type->isSubtypeOf(typ)) { + return true; + } + } + return false; + } + bool maybeOfKind(TypeKind kind, const TypePtr& actual_type) { + if (actual_type->kind() == AnyType::Kind) { + return true; + } + if (auto op = actual_type->cast()) { + return op->getElementType()->kind() == kind; + } + return false; + } + bool staticallyFalse(const TypePtr& actual_type) { + if ((list_check && maybeOfKind(ListType::Kind, actual_type)) || + (tuple_check && maybeOfKind(TupleType::Kind, actual_type))) { + return false; + } + for (const TypePtr& typ : types) { + if (typ->isSubtypeOf(actual_type)) { + return false; + } + } + return true; + } ScriptTypeParser typeParser_; bool list_check = false; bool tuple_check = false; @@ -1478,22 +1516,27 @@ struct to_ir { GatheredTypes gathered(typeParser_); gathered.gather(classinfo); auto val = emitExpr(obj); - if (val->type()->kind() == OptionalType::Kind) { - throw ErrorReport(obj.range()) - << "Optional isinstance check is not supported, " - << "consider use is/is not None instead"; + RefinementSet refinement; + if (gathered.types.size() == 1 && obj.kind() == TK_VAR) { + std::string ident = Var(obj).name().name(); + Refinement isinstance( + std::move(ident), gathered.types.at(0)); + refinement = RefinementSet({isinstance}, {}); } - if ((gathered.list_check && val->type()->kind() == ListType::Kind) || - (gathered.tuple_check && val->type()->kind() == TupleType::Kind)) { - return CondValue(*graph, obj.range(), true); + if (gathered.staticallyTrue(val->type())) { + return CondValue(*graph, obj.range(), true, std::move(refinement)); } - for (const TypePtr& typ : gathered.types) { - if (val->type()->isSubtypeOf(typ)) { - return CondValue(*graph, obj.range(), true); - } + if (gathered.staticallyFalse(val->type())) { + return CondValue(*graph, obj.range(), false, std::move(refinement)); } - return CondValue(*graph, obj.range(), false); + // check maybe true/false at runtime, need an actual op + Value* result = + graph + ->insertNode(graph->createIsInstance( + val, gathered.types, gathered.list_check, gathered.tuple_check)) + ->output(); + return CondValue(result, std::move(refinement), c10::nullopt); } void emitIf(const If& stmt) {