Skip to content

Commit

Permalink
add dynamic isinstance (#26269)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch/pytorch#26269

Previously, isinstance only worked when we could statically determine
whether it was true/false. Now we can actually issue an isinstance check
in cases where the result depends on the runtime type, e.g. Optional[int]
being an instance of int. This is not very useful on its own yet,
but together with type refinement and allowing Any as an argument type, it
will allow for Python-style "overloaded" functions so that we can
remove our __overload__ support.

Test Plan: Imported from OSS

Differential Revision: D17412853

Pulled By: zdevito

fbshipit-source-id: e2c37040f25f6b94ee1676854fceecd22de190ef
  • Loading branch information
zdevito authored and facebook-github-bot committed Oct 1, 2019
1 parent 8a38a53 commit becf080
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 26 deletions.
7 changes: 5 additions & 2 deletions aten/src/ATen/core/interned_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ namespace c10 {
_(prim, IgnoredPythonOp) \
_(prim, Reverse) \
_(prim, Return) \
_(prim, ReturnStmt) \
_(prim, ReturnStmt) \
_(prim, BreakStmt) \
_(prim, ContinueStmt) \
_(prim, Store) \
Expand Down Expand Up @@ -93,6 +93,7 @@ namespace c10 {
_(prim, enumerate) \
_(prim, range) \
_(prim, rangelist) \
_(prim, isinstance) \
_(aten, _grad_sum_to_size) \
_(aten, _size_if_not_equal) \
_(aten, _ncf_unsqueeze) \
Expand Down Expand Up @@ -221,7 +222,9 @@ namespace c10 {
_(attr, beg) \
_(attr, idx) \
_(attr, split) \
_(attr, slot)
_(attr, slot) \
_(attr, kinds) \
_(attr, types)
#else
#define FORALL_NS_SYMBOLS(_) \
_(namespaces, prim) \
Expand Down
26 changes: 21 additions & 5 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6778,11 +6778,12 @@ def test(inp, typ, type_hint):
test(inp, typ, type_hint)

# test optional isinstance check
with self.assertRaisesRegex(RuntimeError, "Optional isinstance check is not supported"):
@torch.jit.script
def opt_func(x):
# type: (Optional[int]) -> bool
return isinstance(x, int)
@torch.jit.script
def opt_func(x):
# type: (Optional[int]) -> bool
return isinstance(x, int)
self.assertTrue(opt_func(3))
self.assertFalse(opt_func(None))

def test_dropout_eval(self):
class ScriptedConv2d(torch.jit.ScriptModule):
Expand Down Expand Up @@ -14027,6 +14028,21 @@ def test_non_primitive_types(x):
out = test_non_primitive_types(_MyNamedTuple(value=torch.tensor(5.0)))
self.assertEqual(out, torch.tensor(6.0))

def test_isinstance_dynamic(self):
    # Exercises the dynamic isinstance path: the static type of `a`
    # (Optional[List[int]]) cannot decide these checks at compile time,
    # so the compiler must emit a runtime isinstance check.
    @torch.jit.script
    def foo(a):
        # type: (Optional[List[int]]) -> int
        b = 0
        # Tuples of types (including nested tuples) mirror Python's
        # built-in isinstance semantics.
        if isinstance(a, (int, (float,), list, str)):
            b += 1
        if isinstance(a, (int, str)):
            b += 1
        # Container types with element parameters are also accepted.
        if isinstance(a, List[int]):
            b += 1
        return b
    # A list of ints passes the first and third checks only.
    self.assertEqual(foo([3, 4]), 2)
    # None matches none of the checks.
    self.assertEqual(foo(None), 0)

def test_function_overloads(self):
# TODO: pyflakes currently does not compose @overload annotation with other
# decorators. This is fixed on master but not on version 2.1.1.
Expand Down
19 changes: 19 additions & 0 deletions torch/csrc/jit/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1517,6 +1517,25 @@ Node* Graph::createLoad(const std::string& name, const TypePtr& type) {
return n;
}

Node* Graph::createIsInstance(
    Value* v,
    at::ArrayRef<TypePtr> types,
    bool is_list,
    bool is_tuple) {
  // Builds a prim::isinstance node that tests `v` at runtime against the
  // given set of concrete types and/or container kinds. The node always
  // produces a single boolean output.
  Node* result = create(prim::isinstance, {v}, /*num_outputs*/ 1);
  // Container-kind checks are encoded as strings in the attr::kinds
  // attribute; concrete types go into attr::types.
  std::vector<std::string> kind_names;
  if (is_list) {
    kind_names.emplace_back("list");
  }
  if (is_tuple) {
    kind_names.emplace_back("tuple");
  }
  result->ss_(attr::kinds, std::move(kind_names));
  result->tys_(attr::types, types.vec());
  result->output()->setType(BoolType::get());
  return result;
}

Value* Graph::insertFunctionCall(
Function* callee,
const script::MatchedSchema& matched) {
Expand Down
5 changes: 5 additions & 0 deletions torch/csrc/jit/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -1094,6 +1094,11 @@ struct Graph {
}
TORCH_API Node* createStore(const std::string& name, Value* v);
TORCH_API Node* createLoad(const std::string& name, const TypePtr& type);
TORCH_API Node* createIsInstance(
Value* v,
at::ArrayRef<TypePtr> types,
bool is_list,
bool is_tuple);

TORCH_API Value* insertFunctionCall(
Function* callee,
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/jit/passes/alias_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ void AliasDb::analyzeImpl(Node* node) {
return analyzeConservative(node);
case prim::Print:
case prim::Uninitialized:
case prim::isinstance:
// These ops do nothing
return;
default:
Expand Down Expand Up @@ -1246,6 +1247,7 @@ bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
prim::CallFunction,
prim::CallMethod,
aten::wait,
prim::isinstance,
};

// Operators that should not be used by alias analysis
Expand Down
54 changes: 51 additions & 3 deletions torch/csrc/jit/passes/python_print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -764,9 +764,7 @@ struct PythonPrintPass {
registerClassDependencies(containedType);
}
}

void printNode(Node* node, bool print_const) {
WithSourceRange guard(&source_range_stack_, node);
void scanTypeDependencies(Node* node) {
// Check for class dependencies. If this node inputs or outputs a class
// type, we need to add it to our table of dependencies.
for (const auto input : node->inputs()) {
Expand All @@ -775,7 +773,26 @@ struct PythonPrintPass {
for (const auto output : node->outputs()) {
registerClassDependencies(output->type());
}
for (const auto& name : node->attributeNames()) {
switch (node->kindOf(name)) {
case AttributeKind::ty:
registerClassDependencies(node->ty(name));
break;
case AttributeKind::tys:
for (const TypePtr& t : node->tys(name)) {
registerClassDependencies(t);
}
break;
default:
// noop
break;
}
}
}

void printNode(Node* node, bool print_const) {
WithSourceRange guard(&source_range_stack_, node);
scanTypeDependencies(node);
if (!print_const && node->kind() == prim::Constant)
return;
splitLongInlines(node->inputs());
Expand Down Expand Up @@ -1130,6 +1147,36 @@ struct PythonPrintPass {
}
stmt << ")";
} break;
case prim::isinstance: {
stmt << "isinstance(" << useOf(node->input()) << ", ";
const auto& types = node->tys(attr::types);
const auto& kinds = node->ss(attr::kinds);
if (types.size() == 1 && kinds.size() == 0) {
stmt << types.at(0)->python_str();
} else if (kinds.size() == 1 && types.size() == 0) {
stmt << kinds.at(0);
} else {
// check multiple things, e.g. (str, list, int)
stmt << "(";
bool first = true;
for (const TypePtr& typ : types) {
if (!first) {
stmt << ", ";
}
stmt << typ->python_str();
first = false;
}
for (const std::string& kind : kinds) {
if (!first) {
stmt << ", ";
}
stmt << kind;
first = false;
}
stmt << ")";
}
stmt << ")";
} break;
default: {
printOpName(stmt, node->kind());
const FunctionSchema& schema = node->schema();
Expand Down Expand Up @@ -1497,6 +1544,7 @@ bool printerHasSpecialCaseFor(Symbol sym) {
prim::GetAttr,
prim::SetAttr,
prim::CallFunction,
prim::isinstance,
};

// WARNING: by adding a value to this set, you are asserting that your
Expand Down
33 changes: 33 additions & 0 deletions torch/csrc/jit/register_prim_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1314,6 +1314,39 @@ RegisterOperators reg(
return 0;
};
},
aliasAnalysisSpecialCase()),
// Interpreter implementation of prim::isinstance (the dynamic check
// emitted by the compiler when the result is not statically known).
Operator(
    prim::isinstance,
    [](const Node* node) -> Operation {
      // Decode the node's attributes once, when the Operation is built:
      // attr::types holds concrete TypePtrs to test against and
      // attr::kinds holds container-kind strings ("list"/"tuple").
      std::vector<TypePtr> types = node->tys(attr::types);
      bool is_list = false;
      bool is_tuple = false;
      for (const std::string& kind : node->ss(attr::kinds)) {
        if (kind == "list") {
          is_list = true;
        } else if (kind == "tuple") {
          is_tuple = true;
        } else {
          // Only "list" and "tuple" are produced by createIsInstance.
          TORCH_INTERNAL_ASSERT(false, "unrecognized type kind ", kind);
        }
      }
      // At runtime: pop the value and push true iff its dynamic type is
      // a list/tuple (when requested) or a subtype of any listed type.
      return [types, is_list, is_tuple](Stack& stack) {
        TypePtr ty = pop(stack).type();
        if ((is_list && ty->kind() == ListType::Kind) ||
            (is_tuple && ty->kind() == TupleType::Kind)) {
          push(stack, true);
          return 0;
        }
        for (const TypePtr& to_check : types) {
          if (ty->isSubtypeOf(to_check)) {
            push(stack, true);
            return 0;
          }
        }
        push(stack, false);
        return 0;
      };
    },
    aliasAnalysisSpecialCase())});

RegisterOperators logging_operators(
Expand Down
75 changes: 59 additions & 16 deletions torch/csrc/jit/script/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,13 @@ struct CondValue {
: value_(value),
refinements_(std::move(refinements)),
static_if_(static_if) {}
CondValue(Graph& g, const SourceRange& loc, bool static_value)
CondValue(
Graph& g,
const SourceRange& loc,
bool static_value,
RefinementSet refinements)
: value_(g.insertConstant(static_value, loc)),
refinements_({}),
refinements_(std::move(refinements)),
static_if_(static_value) {}
Value* value() const {
return value_;
Expand Down Expand Up @@ -1023,12 +1027,12 @@ struct to_ir {
// MA, MM, MN, NM, NN, AM -> cannot prove anything statically
bool its_is = expr.kind() == TK_IS;
if (lhs_none == ALWAYS && rhs_none == ALWAYS) {
return CondValue(*graph, expr.range(), its_is);
return CondValue(*graph, expr.range(), its_is, {});
} else if (
(lhs_none == ALWAYS && rhs_none == NEVER) ||
(lhs_none == NEVER && rhs_none == ALWAYS)) {
// lhs_val/rhs_val with A/M: only emit never_none_branch
return CondValue(*graph, expr.range(), !its_is);
return CondValue(*graph, expr.range(), !its_is, {});
} else {
auto kind = getNodeKind(expr.kind(), expr.get()->trees().size());
Value* cond_value = emitBuiltinCall(
Expand Down Expand Up @@ -1470,6 +1474,40 @@ struct to_ir {
TypePtr type = typeParser_.parseTypeFromExpr(classinfo);
types.emplace_back(type);
}
// Is this isinstance check guaranteed to pass given the static type of
// the value being tested?
bool staticallyTrue(const TypePtr& actual_type) {
  const auto actual_kind = actual_type->kind();
  // A requested container-kind check is statically true when the static
  // type already has that kind.
  if (list_check && actual_kind == ListType::Kind) {
    return true;
  }
  if (tuple_check && actual_kind == TupleType::Kind) {
    return true;
  }
  // Likewise when the static type is a subtype of any requested type.
  for (const TypePtr& candidate : types) {
    if (actual_type->isSubtypeOf(candidate)) {
      return true;
    }
  }
  return false;
}
// Could a value whose static type is `actual_type` turn out to have the
// given kind at runtime? True for Any, and for an Optional whose element
// type has that kind.
bool maybeOfKind(TypeKind kind, const TypePtr& actual_type) {
  if (actual_type->kind() == AnyType::Kind) {
    return true;
  }
  auto opt = actual_type->cast<OptionalType>();
  return opt != nullptr && opt->getElementType()->kind() == kind;
}
// Is this isinstance check guaranteed to fail given the static type of
// the value being tested? Conservative: any possibility of a runtime
// match means the answer is "no".
bool staticallyFalse(const TypePtr& actual_type) {
  // A requested container-kind check could still match at runtime if the
  // static type might dynamically be of that kind.
  if (list_check && maybeOfKind(ListType::Kind, actual_type)) {
    return false;
  }
  if (tuple_check && maybeOfKind(TupleType::Kind, actual_type)) {
    return false;
  }
  // A requested type that is a subtype of the static type (e.g. int vs.
  // Optional[int]) could also match at runtime.
  for (const TypePtr& candidate : types) {
    if (candidate->isSubtypeOf(actual_type)) {
      return false;
    }
  }
  return true;
}
ScriptTypeParser typeParser_;
bool list_check = false;
bool tuple_check = false;
Expand All @@ -1478,22 +1516,27 @@ struct to_ir {
GatheredTypes gathered(typeParser_);
gathered.gather(classinfo);
auto val = emitExpr(obj);
if (val->type()->kind() == OptionalType::Kind) {
throw ErrorReport(obj.range())
<< "Optional isinstance check is not supported, "
<< "consider use is/is not None instead";
RefinementSet refinement;
if (gathered.types.size() == 1 && obj.kind() == TK_VAR) {
std::string ident = Var(obj).name().name();
Refinement isinstance(
std::move(ident), gathered.types.at(0));
refinement = RefinementSet({isinstance}, {});
}

if ((gathered.list_check && val->type()->kind() == ListType::Kind) ||
(gathered.tuple_check && val->type()->kind() == TupleType::Kind)) {
return CondValue(*graph, obj.range(), true);
if (gathered.staticallyTrue(val->type())) {
return CondValue(*graph, obj.range(), true, std::move(refinement));
}
for (const TypePtr& typ : gathered.types) {
if (val->type()->isSubtypeOf(typ)) {
return CondValue(*graph, obj.range(), true);
}
if (gathered.staticallyFalse(val->type())) {
return CondValue(*graph, obj.range(), false, std::move(refinement));
}
return CondValue(*graph, obj.range(), false);
// check maybe true/false at runtime, need an actual op
Value* result =
graph
->insertNode(graph->createIsInstance(
val, gathered.types, gathered.list_check, gathered.tuple_check))
->output();
return CondValue(result, std::move(refinement), c10::nullopt);
}

void emitIf(const If& stmt) {
Expand Down

0 comments on commit becf080

Please sign in to comment.