Skip to content

Commit

Permalink
Revert D15706021: [jit] Support for type annotations instead of torch…
Browse files Browse the repository at this point in the history
….jit.annotate()

Differential Revision:
D15706021

Original commit changeset: 8bf1459f229d

fbshipit-source-id: 7ae34578560e2dccd0f04af2220445b3999771fe
  • Loading branch information
Will Feng authored and facebook-github-bot committed Jun 11, 2019
1 parent b46e87c commit 7a040f4
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 92 deletions.
15 changes: 5 additions & 10 deletions test/expect/TestScript.test_python_frontend.expect
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
(list
(assign
(variable (ident q))
(None)
(option))
(None))
(assign
(variable (ident q))
(-
Expand All @@ -34,8 +33,7 @@
(variable (ident z))
(ident sigmoid))
(list)
(list)))
(option))
(list))))
(expression statement
(apply
(variable (ident print))
Expand All @@ -44,8 +42,7 @@
(assign
(variable (ident w))
(unary minus
(variable (ident z)))
(option))
(variable (ident z))))
(if
(and
(and
Expand All @@ -58,8 +55,7 @@
(if
(not (variable (ident z)))
(variable (ident x))
(variable (ident y)))
(option)))
(variable (ident y)))))
(list))
(while
(and
Expand All @@ -72,8 +68,7 @@
(list
(assign
(variable (ident q))
(variable (ident x))
(option))))
(variable (ident x)))))
(assert
(eq (const 1) (const 1))
(option (string_literal hello)))
Expand Down
29 changes: 0 additions & 29 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3229,35 +3229,6 @@ def annotate_none_no_optional():
self.checkScript(annotate_none, ())
self.checkScript(annotate_none_no_optional, ())

@unittest.skipIf(PY2, "Python 3 required")
def test_type_annotate_py3(self):
code = dedent("""
import torch
def fn():
a : List[int] = []
b : torch.Tensor = torch.ones(2, 2)
for _ in range(10):
a.append(4)
return a, b
""")

with tempfile.TemporaryDirectory() as tmp_dir:
script_path = os.path.join(tmp_dir, 'script.py')
with open(script_path, 'w') as f:
f.write(code)
fn = get_fn('test_type_annotate_py3', script_path)

self.checkScript(fn, ())

code = dedent("""
def wrong_type():
wrong : List[int] = [0.5]
return wrong
""")

with self.assertRaisesRegex(RuntimeError, "Lists must contain only a single type"):
cu = torch.jit.CompilationUnit(code)

def test_robust_op_resolution(self):
neg = torch.add # misleading name to make sure we resolve by function

Expand Down
6 changes: 1 addition & 5 deletions torch/csrc/jit/script/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1954,12 +1954,8 @@ struct to_ir {
switch (stmt.lhs().kind()) {
case TK_VAR: {
auto v = Var(stmt.lhs());
TypePtr type = nullptr;
if (stmt.type().present()) {
type = typeParser_.parseTypeFromExpr(stmt.type().get());
}
environment_stack->setSugaredVar(
v.range(), v.name().name(), emitSugaredExpr(stmt.rhs(), 1, type));
v.range(), v.name().name(), emitSugaredExpr(stmt.rhs(), 1));
} break;
case TK_TUPLE_LITERAL:
emitTupleAssign(TupleLiteral(stmt.lhs()), stmt.rhs());
Expand Down
39 changes: 16 additions & 23 deletions torch/csrc/jit/script/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ struct ParserImpl {
case TK_NONE: {
auto k = L.cur().kind;
auto r = L.cur().range;
prefix = create_compound(k, r, {});
prefix = c(k, r, {});
L.next();
} break;
case '(': {
Expand Down Expand Up @@ -193,11 +193,11 @@ struct ParserImpl {
case TK_TIMES_EQ:
case TK_DIV_EQ: {
int modifier = L.next().text()[0];
return create_compound(modifier, r, {});
return c(modifier, r, {});
} break;
default: {
L.expect('=');
return create_compound('=', r, {}); // no reduction
return c('=', r, {}); // no reduction
} break;
}
}
Expand All @@ -208,7 +208,7 @@ struct ParserImpl {
auto cond = parseExp();
L.expect(TK_ELSE);
auto false_branch = parseExp(binary_prec);
return create_compound(TK_IF_EXPR, range, {cond, std::move(true_branch), false_branch});
return c(TK_IF_EXPR, range, {cond, std::move(true_branch), false_branch});
}
// parse the longest expression whose binary operators have
// precedence strictly greater than 'precedence'
Expand All @@ -232,7 +232,7 @@ struct ParserImpl {
if (unary_kind == TK_UNARY_MINUS && subexp.kind() == TK_CONST) {
prefix = Const::create(subexp.range(), "-" + Const(subexp).text());
} else {
prefix = create_compound(unary_kind, pos, {subexp});
prefix = c(unary_kind, pos, {subexp});
}
} else {
prefix = parseBaseExp();
Expand Down Expand Up @@ -263,7 +263,7 @@ struct ParserImpl {
continue;
}

prefix = create_compound(kind, pos, {prefix, parseExp(binary_prec)});
prefix = c(kind, pos, {prefix, parseExp(binary_prec)});
}
return Expr(prefix);
}
Expand Down Expand Up @@ -366,17 +366,14 @@ struct ParserImpl {
return Subscript::create(range, Expr(value), subscript_exprs);
}

Maybe<Expr> maybeParseTypeAnnotation() {
TreeRef parseFormalParam(bool kwarg_only) {
auto ident = parseIdent();
TreeRef type;
if (L.nextIf(':')) {
return Maybe<Expr>::create(L.cur().range, parseExp());
type = Maybe<Expr>::create(L.cur().range, parseExp());
} else {
return Maybe<Expr>::create(L.cur().range);
type = Maybe<Expr>::create(L.cur().range);
}
}

TreeRef parseFormalParam(bool kwarg_only) {
auto ident = parseIdent();
TreeRef type = maybeParseTypeAnnotation();
TreeRef def;
if (L.nextIf('=')) {
def = Maybe<Expr>::create(L.cur().range, parseExp());
Expand Down Expand Up @@ -416,12 +413,11 @@ struct ParserImpl {
// alone on a line:
// first[,other,lhs] = rhs
TreeRef parseAssign(const Expr& lhs) {
auto type = maybeParseTypeAnnotation();
auto op = parseAssignmentOp();
auto rhs = parseExpOrExpTuple();
L.expect(TK_NEWLINE);
if (op->kind() == '=') {
return Assign::create(lhs.range(), lhs, Expr(rhs), type);
return Assign::create(lhs.range(), lhs, Expr(rhs));
} else {
// this is an augmented assignment
if (lhs.kind() == TK_TUPLE_LITERAL) {
Expand Down Expand Up @@ -450,7 +446,7 @@ struct ParserImpl {
case TK_RETURN: {
auto range = L.next().range;
Expr value = L.cur().kind != TK_NEWLINE ? parseExpOrExpTuple()
: Expr(create_compound(TK_NONE, range, {}));
: Expr(c(TK_NONE, range, {}));
L.expect(TK_NEWLINE);
return Return::create(range, value);
}
Expand Down Expand Up @@ -537,7 +533,7 @@ struct ParserImpl {
do {
stmts.push_back(parseStmt());
} while (!L.nextIf(TK_DEDENT));
return create_compound(TK_LIST, r, std::move(stmts));
return c(TK_LIST, r, std::move(stmts));
}

Maybe<Expr> parseReturnAnnotation() {
Expand Down Expand Up @@ -619,14 +615,11 @@ struct ParserImpl {

private:
// short helpers to create nodes
TreeRef create_compound(
int kind,
const SourceRange& range,
TreeList&& trees) {
TreeRef c(int kind, const SourceRange& range, TreeList&& trees) {
return Compound::create(kind, range, std::move(trees));
}
TreeRef makeList(const SourceRange& range, TreeList&& trees) {
return create_compound(TK_LIST, range, std::move(trees));
return c(TK_LIST, range, std::move(trees));
}
Lexer L;
SharedParserData& shared;
Expand Down
7 changes: 1 addition & 6 deletions torch/csrc/jit/script/python_tree_views.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,7 @@ void initTreeViewBindings(PyObject* module) {

py::class_<Assign, Stmt>(m, "Assign")
.def(py::init([](const Expr& lhs, const Expr& rhs) {
return Assign::create(
lhs.range(), lhs, rhs, Maybe<Expr>::create(lhs.range()));
}))
.def(py::init([](const Expr& lhs, const Expr& rhs, Expr* type) {
return Assign::create(
lhs.range(), lhs, rhs, wrap_maybe(lhs.range(), type));
return Assign::create(lhs.range(), lhs, rhs);
}));
py::class_<AugAssign, Stmt>(m, "AugAssign")
.def(py::init([](const Expr& lhs, std::string kind_str, const Expr& rhs) {
Expand Down
13 changes: 3 additions & 10 deletions torch/csrc/jit/script/tree_views.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace script {
// | Global(List<Ident> idents) TK_GLOBAL
// -- NB: the only type of Expr's allowed on lhs are Var
// Or a tuple containing Var with an optional terminating Starred
// | Assign(Expr lhs, Expr rhs, Maybe<Expr> type) TK_ASSIGN
// | Assign(Expr lhs, Expr rhs) TK_ASSIGN
// | AugAssign(Expr lhs, AugAssignKind aug_op, Expr rhs) TK_AUG_ASSIGN
// | Return(List<Expr> values) TK_RETURN
// | ExprStmt(List<Expr> expr) TK_EXPR_STMT
Expand Down Expand Up @@ -584,22 +584,15 @@ struct Assign : public Stmt {
static Assign create(
const SourceRange& range,
const Expr& lhs,
const Expr& rhs,
const Maybe<Expr>& type) {
return Assign(Compound::create(TK_ASSIGN, range, {lhs, rhs, type}));
const Expr& rhs) {
return Assign(Compound::create(TK_ASSIGN, range, {lhs, rhs}));
}

Expr lhs() const {
return Expr(subtree(0));
}

Expr rhs() const {
return Expr(subtree(1));
}

Maybe<Expr> type() const {
return Maybe<Expr>(subtree(2));
}
};

struct Return : public Stmt {
Expand Down
11 changes: 2 additions & 9 deletions torch/jit/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from torch._six import PY2
from torch._C._jit_tree_views import *

# Borrowed from cPython implementation
# https://github.com/python/cpython/blob/561612d8456cfab5672c9b445521113b847bd6b3/Lib/textwrap.py#L411#
# Borrowed from cPython implementation
# https://github.com/python/cpython/blob/561612d8456cfab5672c9b445521113b847bd6b3/Lib/textwrap.py#L411#

_reserved_prefix = '__jit'
_reserved_names = {'print'}
Expand Down Expand Up @@ -282,13 +282,6 @@ def build_Assign(ctx, stmt):
lhs = build_expr(ctx, stmt.targets[0])
return Assign(lhs, rhs)

@staticmethod
def build_AnnAssign(ctx, stmt):
rhs = build_expr(ctx, stmt.value)
lhs = build_expr(ctx, stmt.target)
the_type = build_expr(ctx, stmt.annotation)
return Assign(lhs, rhs, the_type)

@staticmethod
def build_Return(ctx, stmt):
r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("return"))
Expand Down

0 comments on commit 7a040f4

Please sign in to comment.