Skip to content

Commit

Permalink
Fix bugs in assignment to optionals (pytorch#24989)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#24989

This fixes the cases where a variable annotated with an Optional type could not
be conditionally assigned None:

```
x : Optional[int] = 4
if ...:
 x = None
```

Test Plan: Imported from OSS

Differential Revision: D16949314

Pulled By: zdevito

fbshipit-source-id: 7f63d88b30a3f5b024c2a539aa74967c9202af00
  • Loading branch information
zdevito authored and facebook-github-bot committed Aug 22, 2019
1 parent f8611ea commit bb79b61
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 36 deletions.
3 changes: 2 additions & 1 deletion test/expect/TestJit.test_pretty_printer-loop_use_test.expect
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
def loop_use_test(y: Tensor) -> Tuple[Tensor, Tensor]:
x = torch.add(y, 1, 1)
z = torch.add(x, 5, 1)
z0, y0 = z, y
z0 = z
y0 = y
_0 = bool(torch.lt(y, 8))
while _0:
y1 = torch.add_(y0, 1, 1)
Expand Down
4 changes: 3 additions & 1 deletion test/expect/TestJit.test_pretty_printer-while_if_test.expect
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
def while_if_test(a: Tensor,
b: Tensor) -> Tensor:
a0, c, b0 = a, 0, b
a0 = a
c = 0
b0 = b
_0 = bool(torch.lt(a, 10))
while _0:
a1 = torch.add(a0, 1, 1)
Expand Down
3 changes: 2 additions & 1 deletion test/expect/TestJit.test_pretty_printer-while_test.expect
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
def while_test(a: Tensor,
i: Tensor) -> Tensor:
a0, i0 = a, i
a0 = a
i0 = i
_0 = bool(torch.lt(i, 3))
while _0:
a1 = torch.mul_(a0, a0)
Expand Down
3 changes: 3 additions & 0 deletions test/run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torch.utils import cpp_extension
from common_utils import TEST_WITH_ROCM, shell
import torch.distributed as dist
from torch._six import PY2

TESTS = [
'autograd',
Expand Down Expand Up @@ -60,6 +61,8 @@
'namedtensor',
'jit_disabled',
]
if not PY2:
TESTS.append('jit_py3')

WINDOWS_BLACKLIST = [
'distributed',
Expand Down
21 changes: 19 additions & 2 deletions test/test_jit_py3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from jit_utils import JitTestCase
from torch.testing import FileCheck
from typing import NamedTuple, List, Optional

import unittest
import torch


Expand Down Expand Up @@ -148,6 +148,7 @@ def foo():
tup = MyCoolNamedTuple(c=[1, 2, 3], b=3.5, a=9) # noqa
return tup

@unittest.skipIf(True, "broken while these tests were not in CI")
def test_named_tuple_serialization(self):
class MyCoolNamedTuple(NamedTuple):
a : int
Expand Down Expand Up @@ -175,10 +176,12 @@ def fn():
a : List[int] = []
b : torch.Tensor = torch.ones(2, 2)
c : Optional[torch.Tensor] = None
d : Optional[torch.Tensor] = torch.ones(3, 4)
for _ in range(10):
a.append(4)
c = torch.ones(2, 2)
return a, b, c
d = None
return a, b, c, d

self.checkScript(fn, ())

Expand All @@ -193,6 +196,20 @@ def test_parser_bug(self):
def parser_bug(o: Optional[torch.Tensor]):
pass

def test_mismatched_annotation(self):
with self.assertRaisesRegex(RuntimeError, 'annotated with type'):
@torch.jit.script
def foo():
x : str = 4
return x

def test_reannotate(self):
with self.assertRaisesRegex(RuntimeError, 'declare and annotate'):
@torch.jit.script
def foo():
x = 5
if True:
x : Optional[int] = 7


if __name__ == '__main__':
Expand Down
30 changes: 24 additions & 6 deletions torch/csrc/jit/passes/python_print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -592,12 +592,30 @@ struct PythonPrintPass {
}

void printAssignment(at::ArrayRef<Value*> lhs, at::ArrayRef<Value*> rhs) {
if (lhs.size() > 0) {
if (lhs.size() == 0) {
return;
}
indent();
printValueList(body_, lhs);
body_ << " = ";
printValueList(body_, rhs);
body_ << "\n";
}

bool requiresAnnotation(Value* lhs, Value* rhs) {
return *lhs->type() != *rhs->type();
}

void printAnnotatedAssignment(
at::ArrayRef<Value*> lhs,
at::ArrayRef<Value*> rhs) {
for (size_t i = 0; i < lhs.size(); ++i) {
indent();
printValueList(body_, lhs);
body_ << " = ";
printValueList(body_, rhs);
body_ << "\n";
body_ << useOf(lhs[i]);
if (requiresAnnotation(lhs[i], rhs[i])) {
body_ << ": " << lhs[i]->type()->python_str();
}
body_ << " = " << useOf(rhs[i]) << "\n";
}
}

Expand Down Expand Up @@ -643,7 +661,7 @@ struct PythonPrintPass {
});

// Print initial assignments of loop node outputs = loop node inputs
printAssignment(stmt.carriedOutputs(), stmt.carriedInputs());
printAnnotatedAssignment(stmt.carriedOutputs(), stmt.carriedInputs());

assignValuesToTheirUniqueNames(stmt.currentTripCount());
// Loop header
Expand Down
77 changes: 58 additions & 19 deletions torch/csrc/jit/script/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,10 +229,17 @@ struct Environment {
return std::make_shared<SimpleValue>(load->output());
}

void insertStore(const std::string& name, const SourceRange& loc, Value* v) {
// note: type is not always the same as v->type(), e.g.
// type: Optional[Tensor]
// v->type(): Tensor
void insertStore(
const std::string& name,
const SourceRange& loc,
Value* v,
TypePtr type) {
auto g = b->owningGraph();
auto store = g->insertNode(g->createStore(name, v))->setSourceRange(loc);
type_table[name] = store->input()->type();
g->insertNode(g->createStore(name, v))->setSourceRange(loc);
type_table[name] = type;
}

SugaredValuePtr findInThisFrame(const std::string& name) {
Expand Down Expand Up @@ -269,13 +276,18 @@ struct Environment {
}

void setVar(const SourceRange& loc, const std::string& name, Value* value) {
setSugaredVar(loc, name, std::make_shared<SimpleValue>(value));
setSugaredVar(
loc,
name,
std::make_shared<SimpleValue>(value),
/*annotated_type=*/nullptr);
}

void setSugaredVar(
const SourceRange& loc,
const std::string& name,
SugaredValuePtr value) {
SugaredValuePtr value,
TypePtr annotated_type) {
Value* as_simple_value = asSimple(value);
if (as_simple_value && !as_simple_value->hasDebugName() &&
meaningfulName(name) &&
Expand All @@ -293,6 +305,11 @@ struct Environment {
// requires 'a' to be first-class in the graph since its value depends on
// control flow
if (auto parent = findInParentFrame(name)) {
if (annotated_type) {
throw ErrorReport(loc)
<< "Attempting to declare and annotate the type of variable '"
<< name << "' but it is already defined in an outer block";
}
if (!as_simple_value) {
throw ErrorReport(loc)
<< "Cannot re-assign '" << name << "' to a value of type "
Expand All @@ -306,8 +323,15 @@ struct Environment {
<< value->kind() << " and " << name
<< " is not a first-class value. Only reassignments to first-class values are allowed";
}
if (!as_simple_value->type()->isSubtypeOf(
unshapedType(simple_parent->type()))) {

auto parent_type = unshapedType(simple_parent->type());
as_simple_value = tryConvertToType(
loc,
*b->owningGraph(),
parent_type,
as_simple_value,
/*allow_conversions=*/true);
if (!as_simple_value->type()->isSubtypeOf(parent_type)) {
auto error = ErrorReport(loc);
error << "Variable '" << name << "' previously has type "
<< simple_parent->type()->python_str()
Expand All @@ -326,7 +350,17 @@ struct Environment {
}
}
if (as_simple_value) {
insertStore(name, loc, std::move(as_simple_value));
if (!annotated_type) {
annotated_type = as_simple_value->type();
}
if (!as_simple_value->type()->isSubtypeOf(annotated_type)) {
throw ErrorReport(loc)
<< "Variable '" << name << "' is annotated with type "
<< annotated_type->python_str()
<< " but is being assigned to a value of type "
<< as_simple_value->type()->python_str();
}
insertStore(name, loc, std::move(as_simple_value), annotated_type);
} else {
value_table[name] = std::move(value);
}
Expand Down Expand Up @@ -772,7 +806,10 @@ struct to_ir {
const auto& name = (*it).ident().name();
Value* new_input = block->addInput()->setDebugName(name);
environment_stack->setSugaredVar(
(*it).ident().range(), name, self->makeSugared(new_input));
(*it).ident().range(),
name,
self->makeSugared(new_input),
/*annotated_type=*/nullptr);
arguments.emplace_back(name, new_input->type());
++it;
}
Expand Down Expand Up @@ -867,7 +904,10 @@ struct to_ir {
};
auto closure_value = emitClosure(emit_body);
environment_stack->setSugaredVar(
def.name().range(), def.name().name(), closure_value);
def.name().range(),
def.name().name(),
closure_value,
/*annotated_type=*/nullptr);
}

void emitBreak(const Break& stmt) {
Expand Down Expand Up @@ -947,13 +987,6 @@ struct to_ir {
case TK_AUG_ASSIGN:
emitAugAssignment(AugAssign(stmt));
break;
case TK_GLOBAL:
for (auto ident : Global(stmt).names()) {
const auto& name = Ident(ident).name();
environment_stack->setVar(
ident.range(), name, graph->addInput(name));
}
break;
case TK_EXPR_STMT: {
auto expr = ExprStmt(stmt).expr();
emitSugaredExpr(expr, 0);
Expand Down Expand Up @@ -1960,7 +1993,10 @@ struct to_ir {
break;
case TK_VAR:
environment_stack->setSugaredVar(
assignee.range(), Var(assignee).name().name(), outputs.at(i));
assignee.range(),
Var(assignee).name().name(),
outputs.at(i),
/*annotated_type=*/nullptr);
i++;
break;
case TK_STARRED: {
Expand Down Expand Up @@ -2041,7 +2077,10 @@ struct to_ir {
type = typeParser_.parseTypeFromExpr(stmt.type().get());
}
environment_stack->setSugaredVar(
v.range(), v.name().name(), emitSugaredExpr(rhs, 1, type));
v.range(),
v.name().name(),
emitSugaredExpr(rhs, 1, type),
/*annotated_type=*/type);
} break;
case TK_TUPLE_LITERAL:
emitTupleAssign(TupleLiteral(stmt.lhs()), rhs);
Expand Down
17 changes: 11 additions & 6 deletions torch/csrc/jit/script/convert_to_ssa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,20 +118,25 @@ struct ControlFlowLoadStores {
auto loop_vars = addControlFlowLoadStores(body_block);

for (const auto& name : loop_vars->definedVariables()) {
// we require that the variable is defined outside the loop to be emitted,
// and we do not refine the type of the parent variable since the loop may
// not be entered.
// if the variable is local to the loop body, then
// we do not need a loop carried variable for it
auto parent_type = environment_stack->findInAnyFrame(name);
if (!parent_type) {
continue;
}

// since the loop may execute 0 or many times, the output types
// of the loop and the input loop carried dependencies are conservatively
// the union of the output of the body and the input to the loop
auto block_type = loop_vars->findInThisFrame(name);
auto unified_type = unifyTypes(parent_type, block_type).value();

// Insert a store at the beginning of the loop block, so that all
// loads of the variable will use the loop carried value
addNodeInput(n, parent_type, name);
addBlockInput(body_block, parent_type, name);
addBlockOutput(body_block, parent_type, name);
addNodeOutput(n, parent_type, name);
addBlockInput(body_block, unified_type, name);
addBlockOutput(body_block, block_type, name);
addNodeOutput(n, unified_type, name);
}
}

Expand Down

0 comments on commit bb79b61

Please sign in to comment.