diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h
index 5792adeec6..498f4bc355 100644
--- a/aten/src/ATen/core/jit_type.h
+++ b/aten/src/ATen/core/jit_type.h
@@ -782,6 +782,19 @@ struct CAFFE2_API NamedType : public Type {
   c10::optional<QualifiedName> name_;
 };
 
+// Any should never appear in a named type like a class, namedtuple or
+// interface. If it does, then dynamic type information will be lost in the
+// Pickler, leading to hard-to-track-down bugs that will only occur
+// after saving or loading a model. This is because we rely on the
+// static types in named types to reconstruct type tags of loaded
+// values. Lifting this restriction requires solving the serialization
+// problem first.
+CAFFE2_API void checkNoAny(
+    const Type& base,
+    const char* what,
+    const std::string& attrname,
+    const TypePtr& attrtype);
+
 struct TupleType;
 using TupleTypePtr = std::shared_ptr<TupleType>;
 using NameList = std::vector<std::string>;
diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp
index 1bb7959252..97cc9111ef 100644
--- a/aten/src/ATen/core/type.cpp
+++ b/aten/src/ATen/core/type.cpp
@@ -583,6 +583,11 @@ TupleType::TupleType(
       std::any_of(elements_.begin(), elements_.end(), [](TypePtr v) {
         return v->hasFreeVariables();
       });
+  if (schema_) {
+    for (const Argument& arg : schema_->arguments()) {
+      checkNoAny(*this, "attribute", arg.name(), arg.type());
+    }
+  }
 }
 
 bool TupleType::isSubtypeOfExt(const TypePtr rhs_, std::ostream* why_not) const {
@@ -718,4 +723,34 @@ InterfaceType::InterfaceType(QualifiedName name)
 
 InterfaceType::~InterfaceType() = default;
 
+
+static bool containsAny(const TypePtr& type) {
+  std::vector<TypePtr> to_scan = { type };
+  while (!to_scan.empty()) {
+    TypePtr typ = to_scan.back();
+    to_scan.pop_back();
+    if (typ->kind() == AnyType::Kind) {
+      return true;
+    }
+    for (const TypePtr& sub : typ->containedTypes()) {
+      to_scan.emplace_back(sub);
+    }
+  }
+  return false;
+}
+
+void checkNoAny(const Type& base, const char* what, const std::string& attrname, const TypePtr& attrtype) {
+  TORCH_CHECK(
+      !containsAny(attrtype),
+      "attempting to add ",
+      what,
+      " '",
+      attrname,
+      "' of type ",
+      attrtype->python_str(),
+      " to '",
+      base.python_str(),
+      "' but it contains an Any type. Any types cannot be members of modules, classes, or named tuples.");
+}
+
 } // namespace c10
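Note: the checkNoAny guard above is what the new tests later in this patch exercise. As a rough sketch of the user-visible effect (the class name Holder is illustrative, not part of the patch), scripting a class whose attribute type mentions Any now fails up front instead of silently losing type tags at save/load time:

    import torch
    from typing import Any, Tuple

    try:
        @torch.jit.script
        class Holder(object):
            def __init__(self, a):
                # type: (Tuple[int, Any]) -> None
                self.a = a
    except RuntimeError as e:
        print(e)  # "... but it contains an Any type ..."
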
diff --git a/test/cpp/jit/test_autodiff.cpp b/test/cpp/jit/test_autodiff.cpp
index 6bde65e194..b9abd5f3e0 100644
--- a/test/cpp/jit/test_autodiff.cpp
+++ b/test/cpp/jit/test_autodiff.cpp
@@ -66,8 +66,7 @@ std::shared_ptr<Graph> trace(
   auto input_typeptr = TupleType::create(std::move(input_types));
   std::shared_ptr<tracer::TracingState> state;
   Stack trace_stack_in;
-  std::tie(state, trace_stack_in) =
-      tracer::enter(tracer::TypedStack(input_vars, input_typeptr));
+  std::tie(state, trace_stack_in) = tracer::enter(input_vars);
   variable_list trace_vars_in = fmap(
       trace_stack_in, [](const IValue& v) { return Variable(v.toTensor()); });
   auto trace_vars_out = test(trace_vars_in);
diff --git a/test/test_jit.py b/test/test_jit.py
index 60f79f531f..d32ffe66ff 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -14,7 +14,7 @@
 from torch._C import TensorType, BoolType, parse_ir, _propagate_shapes
 from torch._six import inf, PY2, PY37, StringIO
 from torch.autograd import Variable, Function
-from torch.jit.annotations import BroadcastingList2, BroadcastingList3  # noqa: F401
+from torch.jit.annotations import BroadcastingList2, BroadcastingList3, Any  # noqa: F401
 from torch.jit.frontend import NotSupportedError
 from torch.onnx import OperatorExportTypes
 from torch.testing import FileCheck
@@ -6987,6 +6987,27 @@ def foo2(a, b):
         self.assertEqual(foo2(None, 4), 0)
         self.assertEqual(foo2(4, None), 0)
 
+        @torch.jit.script
+        def any_refinement(a, b):
+            # type: (Any, Any) -> int
+            if isinstance(a, int) and isinstance(b, int):
+                return a + b
+            return 0
+
+        self.assertEqual(any_refinement(3, 4), 7)
+        self.assertEqual(any_refinement(3, "hi"), 0)
+
+    def test_any_in_class_fails(self):
+        with self.assertRaisesRegex(RuntimeError, "contains an Any"):
+            @torch.jit.script
+            class Foo(object):
+                def __init__(self, a):
+                    # type: (Tuple[int,Any]) -> None
+                    self.a = a
+
+                def hi(self):
+                    pass
+
     def test_isinstance(self):
         # test isinstance operator for static type checking
         template = dedent('''
diff --git a/test/test_jit_py3.py b/test/test_jit_py3.py
index 41d28d478e..7b39821eea 100644
--- a/test/test_jit_py3.py
+++ b/test/test_jit_py3.py
@@ -1,7 +1,7 @@
 from common_utils import run_tests
 from jit_utils import JitTestCase
 from torch.testing import FileCheck
-from typing import NamedTuple, List, Optional
+from typing import NamedTuple, List, Optional, Any
 import unittest
 import sys
 import torch
@@ -230,5 +230,17 @@ def foo():
             x : Optional[int] = 7
 
+    def test_any_in_class_fails(self):
+        class MyCoolNamedTuple(NamedTuple):
+            a : Any
+            b : float
+            c : List[int]
+        with self.assertRaisesRegex(RuntimeError, "contains an Any"):
+            @torch.jit.script
+            def foo():
+                return MyCoolNamedTuple(4, 5.5, [3])
+            print(foo.graph)
+
+
 
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py
index 33ea9531b6..9c4c65bcd5 100644
--- a/torch/_jit_internal.py
+++ b/torch/_jit_internal.py
@@ -521,7 +521,7 @@ def _get_overloaded_methods(method, mod_class):
 
 try:
     import typing
-    from typing import Tuple, List, Dict, Optional
+    from typing import Tuple, List, Dict, Optional, Any
 
     def is_tuple(ann):
         # For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule
@@ -607,10 +607,14 @@ class OptionalCls(object):
         def __getitem__(self, types):
             return OptionalInstance(types)
 
+    class AnyCls(object):
+        pass
+
     Tuple = TupleCls()  # noqa: T484
     List = ListCls()  # noqa: T484
     Dict = DictCls()  # noqa: T484
     Optional = DictCls()  # noqa: T484
+    Any = AnyCls()  # noqa: T484
 
     def is_tuple(ann):
         return isinstance(ann, TupleInstance)
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index 3dfd066fcd..71eb5f0305 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -243,7 +243,7 @@ void initJITBindings(PyObject* module) {
         Stack stack;
         stack.reserve(inputs.size()); // captures?
         for (auto& obj : inputs) {
-          stack.push_back(toIValue(obj));
+          stack.push_back(toTypeInferredIValue(obj));
         }
         ArgumentSpec spec = arg_spec_creator.create(with_grad, stack);
         arg_spec_creator.specializeTypes(*graph, spec);
@@ -314,7 +314,7 @@ void initJITBindings(PyObject* module) {
       [](std::shared_ptr<Graph> g,
          py::tuple args,
          const std::string& unqualified_op_name) {
-        auto stack = toStack(args);
+        auto stack = toTraceableStack(args);
        checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
       })
       .def(
@@ -535,7 +535,7 @@ void initJITBindings(PyObject* module) {
        // Convert the output of the user-supplied funciton to IValue. The type
        // information of this IValue is used both to record the correct type in
        // the trace.
-        output_ivalue = toIValue(py_func_output);
+        output_ivalue = toTypeInferredIValue(py_func_output);
        Value* out_val = jit::tracer::getValueTrace(output_ivalue);
        body_block->registerOutput(out_val);
        node_output =
@@ -556,7 +556,7 @@ void initJITBindings(PyObject* module) {
          return PythonFutureWrapper(retval);
        } else {
-          auto result = toIValue(f(*args_tup));
+          auto result = toTypeInferredIValue(f(*args_tup));
          auto retval = c10::make_intrusive<c10::ivalue::Future>(result.type());
          retval->markCompleted(std::move(result));
          return PythonFutureWrapper(retval);
diff --git a/torch/csrc/jit/pybind.h b/torch/csrc/jit/pybind.h
index e1b7c3b540..a6cefa13b5 100644
--- a/torch/csrc/jit/pybind.h
+++ b/torch/csrc/jit/pybind.h
@@ -27,7 +27,7 @@ struct type_caster<torch::jit::IValue> {
 
   bool load(handle src, bool) {
     try {
-      value = torch::jit::toIValue(src);
+      value = torch::jit::toTypeInferredIValue(src);
       return true;
     } catch (std::exception& e) {
       return false;
diff --git a/torch/csrc/jit/pybind_utils.h b/torch/csrc/jit/pybind_utils.h
index 58b87c738b..14d33b1213 100644
--- a/torch/csrc/jit/pybind_utils.h
+++ b/torch/csrc/jit/pybind_utils.h
@@ -44,8 +44,6 @@ namespace jit {
 // that is confusing to display to the end user since it always reports
 // locations in libtorch code rather than user code.
 
-using tracer::TypedStack;
-
 inline std::shared_ptr<script::CompilationUnit> get_python_cu() {
   return py::module::import("torch.jit")
       .attr("_python_cu")
@@ -287,37 +285,24 @@ inline bool isTraceableType(TypePtr type) {
   return false;
 }
 
-inline TypedIValue toTraceableIValue(py::handle input) {
+inline IValue toTypeInferredIValue(py::handle input) {
   auto match = tryToInferType(input);
   if (!match.success()) {
     AT_ERROR(
         "Tracer cannot infer type of ", py::str(input), "\n:", match.reason());
   }
-  auto type = match.type();
-
-  if (isTraceableType(type)) {
-    return TypedIValue(toIValue(input, type), type);
-  }
+  return toIValue(input, match.type());
+}
 
-  AT_ERROR(
+inline Stack toTraceableStack(const py::tuple& inputs) {
+  auto info = toTypeInferredIValue(inputs);
+  AT_CHECK(
+      isTraceableType(info.type()),
       "Type '",
-      type->python_str(),
+      info.type()->python_str(),
       "' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and"
       " Tuples of Tensors can be traced");
-}
-
-inline IValue toIValue(py::handle input) {
-  return toTraceableIValue(input).ivalue();
-}
-
-inline Stack toStack(const py::tuple& inputs) {
-  return toIValue(inputs).toTuple()->elements();
-}
-
-inline TypedStack toTypedStack(const py::tuple& inputs) {
-  auto info = toTraceableIValue(inputs);
-  return TypedStack(
-      info.ivalue().toTuple()->elements(), info.type()->expect<TupleType>());
+  return info.toTuple()->elements();
 }
 
 inline IValue createGenericList(py::handle obj, const TypePtr& elem_type) {
@@ -545,7 +530,7 @@ inline IValue toIValue(
     case TypeKind::CapsuleType:
       AT_ERROR("Capsule Values aren't supported");
     case TypeKind::AnyType:
-      AT_ERROR("AnyType Values aren't supported");
+      return toTypeInferredIValue(obj);
   }
   AT_ERROR(
       "Missing cases in toIValue for type: ",
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index 921e24f93e..e1b2dac196 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -678,6 +678,8 @@ void initPythonIRBindings(PyObject* module_) {
         return self->isSubtypeOf(other);
       });
 
+  py::class_<AnyType, Type, std::shared_ptr<AnyType>>(m, "AnyType")
+      .def_static("get", &AnyType::get);
   py::class_<NumberType, Type, std::shared_ptr<NumberType>>(m, "NumberType")
       .def_static("get", &NumberType::get);
   py::class_<IntType, Type, std::shared_ptr<IntType>>(m, "IntType")
diff --git a/torch/csrc/jit/python_tracer.cpp b/torch/csrc/jit/python_tracer.cpp
index a05e028621..bd78d85a18 100644
--- a/torch/csrc/jit/python_tracer.cpp
+++ b/torch/csrc/jit/python_tracer.cpp
@@ -51,7 +51,7 @@ SourceRange getPythonInterpreterSourceRange() {
 
 std::shared_ptr<Graph> createGraphByTracing(
     const py::function& func,
-    TypedStack trace_inputs,
+    Stack trace_inputs,
     const py::function& var_name_lookup_fn,
     bool force_outplace,
     script::Module* self) {
@@ -78,7 +78,7 @@ std::shared_ptr<Graph> createGraphByTracing(
         "The traced function didn't return any values! Side-effects are not "
         "captured in traces, so it would be a no-op.");
   }
-  tracer::exit({toIValue(out)});
+  tracer::exit({toTypeInferredIValue(out)});
   if (script::getInlineEverythingMode()) {
     Inline(*graph);
   }
@@ -161,10 +161,10 @@ void initPythonTracerBindings(PyObject* module) {
   m.def("_tracer_warn_use_python", []() { tracer::setWarn(pythonWarn); });
   m.def("_tracer_enter", [](py::args trace_inputs) {
-    return tracer::enter(toTypedStack(trace_inputs));
+    return tracer::enter(toTraceableStack(trace_inputs));
   });
   m.def("_tracer_exit", [](py::tuple var_outputs) {
-    tracer::exit(toStack(var_outputs));
+    tracer::exit(toTraceableStack(var_outputs));
   });
   m.def("_tracer_abandon", []() { tracer::abandon(); });
   m.def("_get_tracing_state", []() { return getTracingState(); });
diff --git a/torch/csrc/jit/python_tracer.h b/torch/csrc/jit/python_tracer.h
index 24be865d2b..d865017ce1 100644
--- a/torch/csrc/jit/python_tracer.h
+++ b/torch/csrc/jit/python_tracer.h
@@ -29,7 +29,7 @@ Node* preRecordPythonTrace(
 
 std::shared_ptr<Graph> createGraphByTracing(
     const py::function& func,
-    TypedStack inputs,
+    Stack inputs,
     const py::function& var_name_lookup_fn,
     bool force_outplace,
     script::Module* self = nullptr);
diff --git a/torch/csrc/jit/script/class_type.cpp b/torch/csrc/jit/script/class_type.cpp
index 3c1819ad81..210c7b9ab0 100644
--- a/torch/csrc/jit/script/class_type.cpp
+++ b/torch/csrc/jit/script/class_type.cpp
@@ -65,17 +65,21 @@ size_t ClassType::addAttribute(
     const std::string& name,
     TypePtr type,
     bool is_parameter) {
+  const char* what = is_parameter ? "parameter" : "attribute";
   for (size_t i = 0; i < attributeNames_.size(); ++i) {
     TORCH_CHECK(
         name != attributeNames_[i],
         "attempting to add ",
-        is_parameter ? "parameter"
-                     : "attribute"
-                       " '",
+        what,
+        " '",
        name,
-        "' but a field of the same name already exists with type ",
+        "' to ",
+        python_str(),
+        " but a field of the same name already exists with type ",
        attributeTypes_[i]->python_str());
   }
+  checkNoAny(*this, what, name, type);
+
   size_t slot = attributeNames_.size();
   attributeNames_.push_back(name);
   attributeTypes_.push_back(type);
"parameter" : "attribute"; for (size_t i = 0; i < attributeNames_.size(); ++i) { TORCH_CHECK( name != attributeNames_[i], "attempting to add ", - is_parameter ? "parameter" - : "attribute" - " '", + what, + " '", name, - "' but a field of the same name already exists with type ", + "' to ", + python_str(), + " but a field of the same name already exists with type ", attributeTypes_[i]->python_str()); } + checkNoAny(*this, what, name, type); + size_t slot = attributeNames_.size(); attributeNames_.push_back(name); attributeTypes_.push_back(type); diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp index 54bc4c97ff..19b493e662 100644 --- a/torch/csrc/jit/script/init.cpp +++ b/torch/csrc/jit/script/init.cpp @@ -599,7 +599,7 @@ void initJitScriptBindings(PyObject* module) { bool force_outplace) { // prereq: Module's buffers and parameters are unique // this was ensured in python before calling this function - auto typed_inputs = toTypedStack(input_tuple); + auto typed_inputs = toTraceableStack(input_tuple); auto graph = tracer::createGraphByTracing( func, typed_inputs, var_lookup_fn, force_outplace, &self); const auto method_name = QualifiedName(self.name(), name); @@ -805,7 +805,7 @@ void initJitScriptBindings(PyObject* module) { py::tuple input_tuple, py::function var_lookup_fn, bool force_outplace) { - auto typed_inputs = toTypedStack(input_tuple); + auto typed_inputs = toTraceableStack(input_tuple); auto graph = tracer::createGraphByTracing( func, typed_inputs, var_lookup_fn, force_outplace); auto cu = get_python_cu(); diff --git a/torch/csrc/jit/script/script_type_parser.cpp b/torch/csrc/jit/script/script_type_parser.cpp index b32d7c254f..6b8c2d5198 100644 --- a/torch/csrc/jit/script/script_type_parser.cpp +++ b/torch/csrc/jit/script/script_type_parser.cpp @@ -19,6 +19,7 @@ const std::unordered_map& ident_to_type_lut() { // parsing serialized methods that use implicit converions to Scalar {"number", NumberType::get()}, {"None", NoneType::get()}, + {"Any", AnyType::get()}, }; return map; } @@ -269,21 +270,20 @@ std::vector ScriptTypeParser::parseArgsFromDecl( auto decl_arg = *it; TypePtr type; - c10::optional N; + c10::optional N = c10::nullopt; bool is_inferred_type = false; if (!decl_arg.type().present()) { // If this param doesn't have a type, default to "tensor" is_inferred_type = true; type = TensorType::get(); - N = c10::nullopt; } else { // BroadcastList list can only appear at the argument level - if (auto maybe_broad_list = parseBroadcastList(decl_arg.type().get())) { + Expr type_expr = decl_arg.type().get(); + if (auto maybe_broad_list = parseBroadcastList(type_expr)) { type = maybe_broad_list->first; N = maybe_broad_list->second; } else { type = parseTypeFromExpr(decl_arg.type().get()); - N = c10::nullopt; } } c10::optional default_value = c10::nullopt; diff --git a/torch/csrc/jit/script/sugared_value.cpp b/torch/csrc/jit/script/sugared_value.cpp index f5e425eb59..0aa64a32bb 100644 --- a/torch/csrc/jit/script/sugared_value.cpp +++ b/torch/csrc/jit/script/sugared_value.cpp @@ -212,6 +212,7 @@ void SimpleValue::setAttr( << "Classes that recursively contain instances of themselves" << " are not yet supported"; } + classType->addAttribute(field, newValue->type()); expectedType = newValue->type(); diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp index 7fecd1a940..53b3f70c01 100644 --- a/torch/csrc/jit/tracer.cpp +++ b/torch/csrc/jit/tracer.cpp @@ -146,7 +146,7 @@ Value* TracingState::getValue(const IValue& var) { } return it->second; } - 
diff --git a/torch/csrc/jit/tracer.cpp b/torch/csrc/jit/tracer.cpp
index 7fecd1a940..53b3f70c01 100644
--- a/torch/csrc/jit/tracer.cpp
+++ b/torch/csrc/jit/tracer.cpp
@@ -146,7 +146,7 @@ Value* TracingState::getValue(const IValue& var) {
     }
     return it->second;
   }
-    std::ostringstream oss;
+  std::ostringstream oss;
   if (var.isFuture()) {
     oss << "Tried to trace Future or Object that the tracer was not aware of.";
   } else {
@@ -285,7 +285,7 @@ static void gatherParametersAndBuffers(
     Value* self_value,
     const script::Module& self) {
   Graph& g = *self_value->owningGraph();
-
+  state->setValue(self.module_object(), self_value);
 
   for (script::Slot s : self.get_slots()) {
@@ -304,7 +304,7 @@ static void gatherParametersAndBuffers(
 // varied on subsequent invocations of the trace. Any other variables
 // will be treated as constants.
 std::pair<std::shared_ptr<TracingState>, Stack> enter(
-    TypedStack inputs,
+    Stack inputs,
     script::Module* self) {
   if (isTracing()) {
     AT_ERROR("Tracing can't be nested");
@@ -321,12 +321,10 @@ std::pair<std::shared_ptr<TracingState>, Stack> enter(
   }
   size_t i = 0;
-  auto input_types = inputs.types()->elements();
-  for (IValue& input : inputs.stack()) {
-    input = addInput(state,
-        input, input_types[i++], state->graph->addInput());
+  for (IValue& input : inputs) {
+    input = addInput(state, input, input.type(), state->graph->addInput());
   }
-  return std::make_pair(state, inputs.stack());
+  return std::make_pair(state, inputs);
 }
 
 // Exit a trace, treating 'outputs' as the outputs of the trace. These
diff --git a/torch/csrc/jit/tracer.h b/torch/csrc/jit/tracer.h
index a8685cc4a3..30bb865904 100644
--- a/torch/csrc/jit/tracer.h
+++ b/torch/csrc/jit/tracer.h
@@ -205,31 +205,8 @@ TORCH_API std::function<void()> pauseTracing();
 
 TORCH_API Value* getValueTrace(const IValue& var);
 
-struct TypedStack : public std::pair<Stack, TupleTypePtr>
-{
-  using pair::pair;
-
-  // NB: The inherited default constructor gives nullptr for |type|,
-  // so we provide a saner one.
-  TypedStack()
-    : pair({}, TupleType::create({}))
-  {}
-
-  Stack& stack() {
-    return this->first;
-  }
-  TupleTypePtr& types() {
-    return this->second;
-  }
-  size_t size() {
-    auto s = stack().size();
-    AT_ASSERT(s == types()->elements().size());
-    return s;
-  }
-};
-
 TORCH_API std::pair<std::shared_ptr<TracingState>, Stack> enter(
-    TypedStack inputs,
+    Stack inputs,
     script::Module* self = nullptr);
 
 TORCH_API void exit(const Stack& outputs);
diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py
index e5c91f4fed..120dfbbb8f 100644
--- a/torch/jit/annotations.py
+++ b/torch/jit/annotations.py
@@ -5,9 +5,10 @@
 import torch
 from .._jit_internal import List, BroadcastingList1, BroadcastingList2, \
     BroadcastingList3, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \
-    is_optional, _qualified_name
+    is_optional, _qualified_name, Any
 from torch._C import TensorType, TupleType, FloatType, IntType, \
-    ListType, StringType, DictType, BoolType, OptionalType, ClassType, InterfaceType
+    ListType, StringType, DictType, BoolType, OptionalType, ClassType, InterfaceType, AnyType
+
 from textwrap import dedent
 from torch._six import builtins
 from torch._utils_internal import get_source_lines_and_file
@@ -28,15 +29,6 @@ def __getattr__(self, name):
         raise RuntimeError("Module {} has no member called {}".format(self.name, name))
 
 
-_eval_env = {
-    'torch': Module('torch', {'Tensor': torch.Tensor}),
-    'Tensor': torch.Tensor,
-    'typing': Module('typing', {'Tuple': Tuple}),
-    'Tuple': Tuple,
-    'List': List,
-    'Dict': Dict,
-    'Optional': Optional,
-}
 class EvalEnv(object):
     env = {
         'torch': Module('torch', {'Tensor': torch.Tensor}),
@@ -244,6 +236,8 @@ def ann_to_type(ann, resolver=None):
         return StringType.get()
     elif ann is bool:
         return BoolType.get()
+    elif ann is Any:
+        return AnyType.get()
     elif hasattr(ann, "__torch_script_class__"):
         return ClassType(_qualified_name(ann))
     elif hasattr(ann, "__torch_script_interface__"):
"__torch_script_interface__"): @@ -258,6 +252,7 @@ def ann_to_type(ann, resolver=None): __all__ = [ + 'Any', 'List', 'BroadcastingList1', 'BroadcastingList2', @@ -274,6 +269,7 @@ def ann_to_type(ann, resolver=None): 'ListType', 'StringType', 'DictType', + 'AnyType', 'Module', # TODO: Consider not exporting these during wildcard import (reserve # that for the types; for idiomatic typing code.)