Allow 'Any' to appear as a type argument. (#26572)
Summary:
Pull Request resolved: pytorch/pytorch#26572

Combined with isinstance specialization, this allows a degree of polymorphism
in functions without needing to use our weirder overload hacks.

We do not define any operators on Any, so the only things you can do with a
value of type Any are to put it in a container or to refine its type with an
isinstance check. Any is restricted from appearing in non-argument positions
because we cannot restore type tags if it ends up as a field in a class.
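
For illustration, a minimal sketch of the intended usage, mirroring the tests
added in this commit (nothing beyond torch.jit.script and isinstance
refinement is assumed):

    import torch
    from typing import Any

    @torch.jit.script
    def any_refinement(a, b):
        # type: (Any, Any) -> int
        # No operators are defined on Any; refine with isinstance first.
        if isinstance(a, int) and isinstance(b, int):
            return a + b  # statically known to be ints in this branch
        return 0

    assert any_refinement(3, 4) == 7
    assert any_refinement(3, "hi") == 0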

Test Plan: Imported from OSS

Differential Revision: D17530643

Pulled By: zdevito

fbshipit-source-id: f06f78ce84819f7773953a492f3d4c49219ee94c
zdevito authored and facebook-github-bot committed Oct 16, 2019
1 parent 97b39a2 commit fb45171
Showing 19 changed files with 140 additions and 93 deletions.
13 changes: 13 additions & 0 deletions aten/src/ATen/core/jit_type.h
@@ -782,6 +782,19 @@ struct CAFFE2_API NamedType : public Type {
  c10::optional<QualifiedName> name_;
};

// Any should never appear in a named type like a class, namedtuple or
// interface. If it does, then dynamic type information will be lost in the
// Pickler, leading to hard-to-track-down bugs that will only occur
// after saving or loading a model. This is because we rely on the
// static types in named types to reconstruct type tags of loaded
// values. Lifting this restriction requires solving the serialization
// problem first.
CAFFE2_API void checkNoAny(
    const Type& base,
    const char* what,
    const std::string& attrname,
    const TypePtr& attrtype);

struct TupleType;
using TupleTypePtr = std::shared_ptr<TupleType>;
using NameList = std::vector<std::string>;
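The comment above describes exactly the failure mode the new tests exercise:
an Any-typed field in a scripted class or named tuple is rejected at
compilation time rather than failing after save/load. A short sketch of the
user-visible error (the class Foo here is illustrative):

    import torch
    from typing import Any, Tuple

    # Raises RuntimeError: "... contains an Any type. Any types cannot
    # be members of modules, classes, or named tuples."
    @torch.jit.script
    class Foo(object):
        def __init__(self, a):
            # type: (Tuple[int, Any]) -> None
            self.a = a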
35 changes: 35 additions & 0 deletions aten/src/ATen/core/type.cpp
@@ -583,6 +583,11 @@ TupleType::TupleType(
      std::any_of(elements_.begin(), elements_.end(), [](TypePtr v) {
        return v->hasFreeVariables();
      });
  if (schema_) {
    for (const Argument& arg : schema_->arguments()) {
      checkNoAny(*this, "attribute", arg.name(), arg.type());
    }
  }
}

bool TupleType::isSubtypeOfExt(const TypePtr rhs_, std::ostream* why_not) const {
@@ -718,4 +723,34 @@ InterfaceType::InterfaceType(QualifiedName name)

InterfaceType::~InterfaceType() = default;


static bool containsAny(const TypePtr& type) {
  std::vector<TypePtr> to_scan = { type };
  while (!to_scan.empty()) {
    TypePtr typ = to_scan.back();
    to_scan.pop_back();
    if (typ->kind() == AnyType::Kind) {
      return true;
    }
    for (const TypePtr& sub : typ->containedTypes()) {
      to_scan.emplace_back(sub);
    }
  }
  return false;
}

void checkNoAny(
    const Type& base,
    const char* what,
    const std::string& attrname,
    const TypePtr& attrtype) {
  TORCH_CHECK(
      !containsAny(attrtype),
      "attempting to add ",
      what,
      " '",
      attrname,
      "' of type ",
      attrtype->python_str(),
      " to '",
      base.python_str(),
      "' but it contains an Any type. Any types cannot be members of modules, classes, or named tuples.");
}

} // namespace c10
3 changes: 1 addition & 2 deletions test/cpp/jit/test_autodiff.cpp
@@ -66,8 +66,7 @@ std::shared_ptr<Graph> trace(
  auto input_typeptr = TupleType::create(std::move(input_types));
  std::shared_ptr<tracer::TracingState> state;
  Stack trace_stack_in;
  std::tie(state, trace_stack_in) =
      tracer::enter(tracer::TypedStack(input_vars, input_typeptr));
  std::tie(state, trace_stack_in) = tracer::enter(input_vars);
  variable_list trace_vars_in = fmap(
      trace_stack_in, [](const IValue& v) { return Variable(v.toTensor()); });
  auto trace_vars_out = test(trace_vars_in);
23 changes: 22 additions & 1 deletion test/test_jit.py
@@ -14,7 +14,7 @@
from torch._C import TensorType, BoolType, parse_ir, _propagate_shapes
from torch._six import inf, PY2, PY37, StringIO
from torch.autograd import Variable, Function
from torch.jit.annotations import BroadcastingList2, BroadcastingList3 # noqa: F401
from torch.jit.annotations import BroadcastingList2, BroadcastingList3, Any # noqa: F401
from torch.jit.frontend import NotSupportedError
from torch.onnx import OperatorExportTypes
from torch.testing import FileCheck
@@ -6987,6 +6987,27 @@ def foo2(a, b):
        self.assertEqual(foo2(None, 4), 0)
        self.assertEqual(foo2(4, None), 0)

        @torch.jit.script
        def any_refinement(a, b):
            # type: (Any, Any) -> int
            if isinstance(a, int) and isinstance(b, int):
                return a + b
            return 0

        self.assertEqual(any_refinement(3, 4), 7)
        self.assertEqual(any_refinement(3, "hi"), 0)

    def test_any_in_class_fails(self):
        with self.assertRaisesRegex(RuntimeError, "contains an Any"):
            @torch.jit.script
            class Foo(object):
                def __init__(self, a):
                    # type: (Tuple[int,Any]) -> None
                    self.a = a

                def hi(self):
                    pass

    def test_isinstance(self):
        # test isinstance operator for static type checking
        template = dedent('''
14 changes: 13 additions & 1 deletion test/test_jit_py3.py
@@ -1,7 +1,7 @@
from common_utils import run_tests
from jit_utils import JitTestCase
from torch.testing import FileCheck
from typing import NamedTuple, List, Optional
from typing import NamedTuple, List, Optional, Any
import unittest
import sys
import torch
@@ -230,5 +230,17 @@ def foo():
    x : Optional[int] = 7


    def test_any_in_class_fails(self):
        class MyCoolNamedTuple(NamedTuple):
            a : Any
            b : float
            c : List[int]
        with self.assertRaisesRegex(RuntimeError, "contains an Any"):
            @torch.jit.script
            def foo():
                return MyCoolNamedTuple(4, 5.5, [3])
            print(foo.graph)


if __name__ == '__main__':
    run_tests()
6 changes: 5 additions & 1 deletion torch/_jit_internal.py
@@ -521,7 +521,7 @@ def _get_overloaded_methods(method, mod_class):

try:
    import typing
    from typing import Tuple, List, Dict, Optional
    from typing import Tuple, List, Dict, Optional, Any

    def is_tuple(ann):
        # For some reason Python 3.7 violates the Type[A, B].__origin__ == Type rule
@@ -607,10 +607,14 @@ class OptionalCls(object):
        def __getitem__(self, types):
            return OptionalInstance(types)

    class AnyCls(object):
        pass

    Tuple = TupleCls()  # noqa: T484
    List = ListCls()  # noqa: T484
    Dict = DictCls()  # noqa: T484
    Optional = DictCls()  # noqa: T484
    Any = AnyCls()  # noqa: T484

    def is_tuple(ann):
        return isinstance(ann, TupleInstance)
8 changes: 4 additions & 4 deletions torch/csrc/jit/init.cpp
@@ -243,7 +243,7 @@ void initJITBindings(PyObject* module) {
  Stack stack;
  stack.reserve(inputs.size()); // captures?
  for (auto& obj : inputs) {
    stack.push_back(toIValue(obj));
    stack.push_back(toTypeInferredIValue(obj));
  }
  ArgumentSpec spec = arg_spec_creator.create(with_grad, stack);
  arg_spec_creator.specializeTypes(*graph, spec);
@@ -314,7 +314,7 @@ void initJITBindings(PyObject* module) {
  [](std::shared_ptr<Graph> g,
     py::tuple args,
     const std::string& unqualified_op_name) {
    auto stack = toStack(args);
    auto stack = toTraceableStack(args);
    checkAliasAnnotation(g, std::move(stack), unqualified_op_name);
  })
.def(
@@ -535,7 +535,7 @@ void initJITBindings(PyObject* module) {
// Convert the output of the user-supplied function to IValue. The type
// information of this IValue is used to record the correct type in
// the trace.
output_ivalue = toIValue(py_func_output);
output_ivalue = toTypeInferredIValue(py_func_output);
Value* out_val = jit::tracer::getValueTrace(output_ivalue);
body_block->registerOutput(out_val);
node_output =
node_output =
@@ -556,7 +556,7 @@

    return PythonFutureWrapper(retval);
  } else {
    auto result = toIValue(f(*args_tup));
    auto result = toTypeInferredIValue(f(*args_tup));
    auto retval = c10::make_intrusive<c10::ivalue::Future>(result.type());
    retval->markCompleted(std::move(result));
    return PythonFutureWrapper(retval);
2 changes: 1 addition & 1 deletion torch/csrc/jit/pybind.h
@@ -27,7 +27,7 @@ struct type_caster<torch::jit::IValue> {

  bool load(handle src, bool) {
    try {
      value = torch::jit::toIValue(src);
      value = torch::jit::toTypeInferredIValue(src);
      return true;
    } catch (std::exception& e) {
      return false;
35 changes: 10 additions & 25 deletions torch/csrc/jit/pybind_utils.h
@@ -44,8 +44,6 @@ namespace jit {
// that is confusing to display to the end user since it always reports
// locations in libtorch code rather than user code.

using tracer::TypedStack;

inline std::shared_ptr<script::CompilationUnit> get_python_cu() {
  return py::module::import("torch.jit")
      .attr("_python_cu")
@@ -287,37 +285,24 @@ inline bool isTraceableType(TypePtr type) {
  return false;
}

inline TypedIValue toTraceableIValue(py::handle input) {
inline IValue toTypeInferredIValue(py::handle input) {
  auto match = tryToInferType(input);
  if (!match.success()) {
    AT_ERROR(
        "Tracer cannot infer type of ", py::str(input), "\n:", match.reason());
  }
  auto type = match.type();

  if (isTraceableType(type)) {
    return TypedIValue(toIValue(input, type), type);
  }
  return toIValue(input, match.type());
}

  AT_ERROR(
inline Stack toTraceableStack(const py::tuple& inputs) {
  auto info = toTypeInferredIValue(inputs);
  AT_CHECK(
      isTraceableType(info.type()),
      "Type '",
      type->python_str(),
      info.type()->python_str(),
      "' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and"
      " Tuples of Tensors can be traced");
}

inline IValue toIValue(py::handle input) {
  return toTraceableIValue(input).ivalue();
}

inline Stack toStack(const py::tuple& inputs) {
  return toIValue(inputs).toTuple()->elements();
}

inline TypedStack toTypedStack(const py::tuple& inputs) {
  auto info = toTraceableIValue(inputs);
  return TypedStack(
      info.ivalue().toTuple()->elements(), info.type()->expect<TupleType>());
  return info.toTuple()->elements();
}

inline IValue createGenericList(py::handle obj, const TypePtr& elem_type) {
@@ -545,7 +530,7 @@ inline IValue toIValue(
    case TypeKind::CapsuleType:
      AT_ERROR("Capsule Values aren't supported");
    case TypeKind::AnyType:
      AT_ERROR("AnyType Values aren't supported");
      return toTypeInferredIValue(obj);
  }
  AT_ERROR(
      "Missing cases in toIValue for type: ",
2 changes: 2 additions & 0 deletions torch/csrc/jit/python_ir.cpp
@@ -678,6 +678,8 @@ void initPythonIRBindings(PyObject* module_) {
        return self->isSubtypeOf(other);
      });

  py::class_<AnyType, Type, std::shared_ptr<AnyType>>(m, "AnyType")
      .def_static("get", &AnyType::get);
  py::class_<NumberType, Type, std::shared_ptr<NumberType>>(m, "NumberType")
      .def_static("get", &NumberType::get);
  py::class_<IntType, Type, std::shared_ptr<IntType>>(m, "IntType")
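A quick sanity check of the new binding from Python (a sketch; it assumes
these bindings are exposed on torch._C, as the neighboring NumberType and
IntType bindings suggest):

    import torch

    any_t = torch._C.AnyType.get()
    int_t = torch._C.IntType.get()
    # isSubtypeOf is bound on Type earlier in this file; Any sits at the
    # top of the type hierarchy, so this should print True.
    print(int_t.isSubtypeOf(any_t))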
8 changes: 4 additions & 4 deletions torch/csrc/jit/python_tracer.cpp
@@ -51,7 +51,7 @@ SourceRange getPythonInterpreterSourceRange() {

std::shared_ptr<torch::jit::Graph> createGraphByTracing(
    const py::function& func,
    TypedStack trace_inputs,
    Stack trace_inputs,
    const py::function& var_name_lookup_fn,
    bool force_outplace,
    script::Module* self) {
@@ -78,7 +78,7 @@ std::shared_ptr<torch::jit::Graph> createGraphByTracing(
"The traced function didn't return any values! Side-effects are not "
"captured in traces, so it would be a no-op.");
}
tracer::exit({toIValue(out)});
tracer::exit({toTypeInferredIValue(out)});
if (script::getInlineEverythingMode()) {
Inline(*graph);
}
@@ -161,10 +161,10 @@ void initPythonTracerBindings(PyObject* module) {

m.def("_tracer_warn_use_python", []() { tracer::setWarn(pythonWarn); });
m.def("_tracer_enter", [](py::args trace_inputs) {
return tracer::enter(toTypedStack(trace_inputs));
return tracer::enter(toTraceableStack(trace_inputs));
});
m.def("_tracer_exit", [](py::tuple var_outputs) {
tracer::exit(toStack(var_outputs));
tracer::exit(toTraceableStack(var_outputs));
});
m.def("_tracer_abandon", []() { tracer::abandon(); });
m.def("_get_tracing_state", []() { return getTracingState(); });
2 changes: 1 addition & 1 deletion torch/csrc/jit/python_tracer.h
@@ -29,7 +29,7 @@ Node* preRecordPythonTrace(

std::shared_ptr<Graph> createGraphByTracing(
    const py::function& func,
    TypedStack inputs,
    Stack inputs,
    const py::function& var_name_lookup_fn,
    bool force_outplace,
    script::Module* self = nullptr);
12 changes: 8 additions & 4 deletions torch/csrc/jit/script/class_type.cpp
@@ -65,17 +65,21 @@ size_t ClassType::addAttribute(
    const std::string& name,
    TypePtr type,
    bool is_parameter) {
  const char* what = is_parameter ? "parameter" : "attribute";
  for (size_t i = 0; i < attributeNames_.size(); ++i) {
    TORCH_CHECK(
        name != attributeNames_[i],
        "attempting to add ",
        is_parameter ? "parameter"
                     : "attribute"
        " '",
        what,
        " '",
        name,
        "' but a field of the same name already exists with type ",
        "' to ",
        python_str(),
        " but a field of the same name already exists with type ",
        attributeTypes_[i]->python_str());
  }
  checkNoAny(*this, what, name, type);

  size_t slot = attributeNames_.size();
  attributeNames_.push_back(name);
  attributeTypes_.push_back(type);
4 changes: 2 additions & 2 deletions torch/csrc/jit/script/init.cpp
@@ -599,7 +599,7 @@ void initJitScriptBindings(PyObject* module) {
   bool force_outplace) {
  // prereq: Module's buffers and parameters are unique
  // this was ensured in python before calling this function
  auto typed_inputs = toTypedStack(input_tuple);
  auto typed_inputs = toTraceableStack(input_tuple);
  auto graph = tracer::createGraphByTracing(
      func, typed_inputs, var_lookup_fn, force_outplace, &self);
  const auto method_name = QualifiedName(self.name(), name);
@@ -805,7 +805,7 @@ void initJitScriptBindings(PyObject* module) {
   py::tuple input_tuple,
   py::function var_lookup_fn,
   bool force_outplace) {
  auto typed_inputs = toTypedStack(input_tuple);
  auto typed_inputs = toTraceableStack(input_tuple);
  auto graph = tracer::createGraphByTracing(
      func, typed_inputs, var_lookup_fn, force_outplace);
  auto cu = get_python_cu();