diff --git a/android/pytorch_android/src/main/cpp/pytorch_jni.cpp b/android/pytorch_android/src/main/cpp/pytorch_jni.cpp index 7f09b51a51..de940b6299 100644 --- a/android/pytorch_android/src/main/cpp/pytorch_jni.cpp +++ b/android/pytorch_android/src/main/cpp/pytorch_jni.cpp @@ -506,7 +506,7 @@ class JIValue : public facebook::jni::JavaClass { auto jivalue_first_element = jarray->getElement(0); auto first_element = JIValue::JIValueToAtIValue(jivalue_first_element); - c10::TypePtr typePtr = c10::attemptToRecoverType(first_element); + c10::TypePtr typePtr = first_element.type(); c10::impl::GenericList list{typePtr}; list.reserve(n); list.push_back(first_element); @@ -529,7 +529,7 @@ class JIValue : public facebook::jni::JavaClass { } auto firstEntryValue = JIValue::JIValueToAtIValue(it->second); - c10::TypePtr typePtr = c10::attemptToRecoverType(firstEntryValue); + c10::TypePtr typePtr = firstEntryValue.type(); c10::impl::GenericDict dict{c10::StringType::get(), typePtr}; dict.insert(it->first->toStdString(), firstEntryValue); it++; @@ -552,7 +552,7 @@ class JIValue : public facebook::jni::JavaClass { } auto firstEntryValue = JIValue::JIValueToAtIValue(it->second); - c10::TypePtr typePtr = c10::attemptToRecoverType(firstEntryValue); + c10::TypePtr typePtr = firstEntryValue.type(); c10::impl::GenericDict dict{c10::IntType::get(), typePtr}; dict.insert(it->first->longValue(), firstEntryValue); it++; diff --git a/aten/src/ATen/core/function_schema_inl.h b/aten/src/ATen/core/function_schema_inl.h index 48a0c9579f..18855f6ab9 100644 --- a/aten/src/ATen/core/function_schema_inl.h +++ b/aten/src/ATen/core/function_schema_inl.h @@ -186,12 +186,12 @@ inline void FunctionSchema::checkArg( const IValue& value, const Argument& argument, optional pos) const { - if (!isSubvalueOf(value, argument.type())) { + if (!value.type()->isSubtypeOf(argument.type())) { std::string position = pos ? ::c10::str(" in position ", *pos) : ""; TORCH_CHECK( false, formatTypeMismatchMsg( - argument, attemptToRecoverType(value)->python_str(), pos)); + argument, value.type()->python_str(), pos)); } } diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 498f4bc355..454cae418a 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -1317,10 +1317,6 @@ inline TypePtr getTypePtr() { return detail::getTypePtr_::call(); } -CAFFE2_API TypePtr incompleteInferTypeFrom(const IValue& value); -CAFFE2_API TypePtr attemptToRecoverType(const IValue& input_ivalue); -CAFFE2_API bool isSubvalueOf(const IValue& input_ivalue, TypePtr type); - using TypeEnv = std::unordered_map; struct MatchTypeReturn { MatchTypeReturn(std::string reason) : reason_(std::move(reason)) {} diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index 97cc9111ef..b75810acd3 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -137,117 +137,6 @@ ListTypePtr ListType::ofBools() { return value; } -// why incomplete? You cannot completely recover a type from -// an IValue, List[List[int]] and List[List[Tensor]] will both -// become ivalue.isGenericList() and cannot be recovered. -// The only appropriate place to use this is where you know that -// you are only dealing with a subset of objects where you can recover -// the type, like in the tracer. -TypePtr incompleteInferTypeFrom(const IValue& value) { - if (value.isTensor()) { - return TensorType::create(value.toTensor()); - } else if (value.isDouble()) { - return FloatType::get(); - } else if (value.isInt()) { - return IntType::get(); - } else if (value.isBool()) { - return BoolType::get(); - } else if (value.isString()) { - return StringType::get(); - } else if (value.isIntList()) { - return ListType::ofInts(); - } else if (value.isTensorList()) { - return ListType::ofTensors(); - } else if (value.isBoolList()) { - return ListType::ofBools(); - } else if (value.isDoubleList()) { - return ListType::ofFloats(); - } else if (value.isTuple()) { - return TupleType::create(fmap(value.toTuple()->elements(), incompleteInferTypeFrom)); - } else if (value.isDevice()) { - return DeviceObjType::get(); - } else if (value.isObject()) { - return value.toObject()->type(); - } - AT_ERROR("Type cannot be accurately recovered from this IValue."); -} - -// This attempts to recover the type from an IValue, including nested Generic -// Lists. It only examines the first element (the first of the iterator in the -// case of a dict) of each generic container, -// and if a generic container is empty returns typevar as the base element. -// XXX: only used for better error messages, should not be used elsewhere -TypePtr attemptToRecoverType(const IValue& ivalue) { - if (ivalue.isGenericList()) { - auto ivalue_list = ivalue.toGenericListRef(); - if (ivalue_list.size() == 0) { - return ListType::create(VarType::create("t")); - } - return ListType::create(attemptToRecoverType(ivalue_list[0])); - } - if (ivalue.isGenericDict()) { - auto dict = ivalue.toGenericDict(); - if (dict.size() == 0) { - return DictType::create(VarType::create("t"), VarType::create("t")); - } - auto item = dict.begin(); - return DictType::create( - attemptToRecoverType(item->key()), attemptToRecoverType(item->value())); - } - return incompleteInferTypeFrom(ivalue); -} - -// Checks if input_ivalue is a subvalue of type. -bool isSubvalueOf(const IValue& ivalue, TypePtr type) { - if (auto optional = type->cast()) { - // Unwrap the optional if the ivalue is not none - if (ivalue.isNone()) { - return true; - } else { - return isSubvalueOf(ivalue, optional->getElementType()); - } - } - - if (ivalue.isTuple()) { - auto elems = ivalue.toTuple()->elements(); - auto tuple_type = type->cast(); - if (!tuple_type || tuple_type->elements().size() != elems.size()) { - return false; - } - auto type_elem = tuple_type->elements(); - bool is_subvalue = true; - for (size_t i = 0; i < type_elem.size() && is_subvalue; ++i) { - is_subvalue = isSubvalueOf(elems[i], type_elem[i]); - } - return is_subvalue; - } - if (ivalue.isGenericList()) { - auto list_type = type->cast(); - if (!list_type) { - return false; - } - auto ivalue_list = ivalue.toGenericListRef(); - auto element_type = list_type->getElementType(); - return std::all_of(ivalue_list.begin(), ivalue_list.end(), [&](const IValue& list_elem) { - return isSubvalueOf(list_elem, element_type); - }); - } - if (ivalue.isGenericDict()) { - auto dict_type = type->expect(); - const auto dict = ivalue.toGenericDict(); - return std::all_of( - dict.begin(), dict.end(), [=](const c10::impl::GenericDict::iterator::value_type& item) { - return isSubvalueOf(item.key(), dict_type->getKeyType()) && - isSubvalueOf(item.value(), dict_type->getValueType()); - }); - } - if (ivalue.isObject()) { - return ivalue.toObjectRef().type()->isSubtypeOf(type); - } - - return incompleteInferTypeFrom(ivalue)->isSubtypeOf(type); -} - c10::optional tryEitherIsTheSuperType(const TypePtr& t1, const TypePtr& t2) { if (t1->isSubtypeOf(t2)) { return t2; diff --git a/test/cpp/api/jit.cpp b/test/cpp/api/jit.cpp index 073b7f2f50..fa235788d1 100644 --- a/test/cpp/api/jit.cpp +++ b/test/cpp/api/jit.cpp @@ -65,7 +65,7 @@ TEST(TorchScriptTest, TestNestedIValueModuleArgMatching) { std::string(error.what_without_backtrace()) .find("nested_loop() Expected a value of type 'List[List[Tensor]]'" " for argument 'a' but instead found type " - "'List[List[List[t]]]'") == 0); + "'List[List[List[Tensor]]]'") == 0); }; } diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp index d8d49807dc..56e8f2fb42 100644 --- a/torch/csrc/jit/import.cpp +++ b/torch/csrc/jit/import.cpp @@ -129,7 +129,16 @@ IValue ScriptModuleDeserializer::readArchive(const std::string& archive_name) { // XXX: Do not optimize __setstate__, so that we don't try to // specialize the class before it is initialized. setGraphExecutorOptimize(false); - (*type.type_->getMethod("__setstate__"))({obj, input}); + Function* set_state = type.type_->getMethod("__setstate__"); + // since we are in the middle of unpickling we might still have lists and + // dicts that do not have accurate tags (e.g. they report they are + // List[Any]). But we need to run __setstate__ which will check the input + // type and may access the tags. Since setstate has a known input type, we + // can correctly restore the tags now by apply the input type of set_state + // to the state object being passed. + restoreAccurateTypeTags( + input, set_state->getSchema().arguments().at(1).type()); + (*set_state)({obj, input}); setGraphExecutorOptimize(true); postSetStateValidate(obj); return obj; diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp index 71eb5f0305..a72b60bcb3 100644 --- a/torch/csrc/jit/init.cpp +++ b/torch/csrc/jit/init.cpp @@ -253,7 +253,7 @@ void initJITBindings(PyObject* module) { auto g_inputs = graph->inputs(); for (size_t i = 0; i < inputs.size(); ++i) { if (stack[i].isTensor()) { - g_inputs[i]->setType(incompleteInferTypeFrom(stack[i])); + g_inputs[i]->setType(stack[i].type()); } } PropagateInputShapes(graph); diff --git a/torch/csrc/jit/unpickler.cpp b/torch/csrc/jit/unpickler.cpp index 5295c5f5aa..2d3e72d946 100644 --- a/torch/csrc/jit/unpickler.cpp +++ b/torch/csrc/jit/unpickler.cpp @@ -32,26 +32,9 @@ PicklerClass getClass(const std::string& str) { AT_ERROR("Unknown class name for unpickler: ", str); } -static void postSetStateValidate(const IValue& v) { - auto obj = v.toObject(); - const auto& objType = obj->type(); - for (size_t i = 0; i < objType->numAttributes(); i++) { - const auto& attrType = objType->getAttribute(i); - const auto& attrName = objType->getAttributeName(i); - const auto& slot = obj->getSlot(i); - // const auto attrType = objType->getAttribute(i); - // Verify that all the non-optional attributes have been initialized - // TODO: Issue #20497 - if (attrType->kind() != TypeKind::OptionalType) { - TORCH_CHECK( - !slot.isNone(), - "The field '", - attrName, - "' was left unitialized after __setstate__, but expected a ", - "value of type '", - attrType->python_str(), - "'"); - } +static void restoreAccurateTypeTagsIfPossible(const IValue& root) { + if (root.isObject()) { + restoreAccurateTypeTags(root, root.type()); } } @@ -66,15 +49,12 @@ static void postSetStateValidate(const IValue& v) { // the top-level unpickled thing (which is guarenteed for Modules, but // not for torch.load/torch,save). Otherwise we do not know the types // of the contained objects and cannot restore the tags. -static void restoreAccurateTypeTagsIfPossible(const IValue& root) { - if (!root.isObject()) { - return; - } +void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) { struct Work { TypePtr static_type; IValue value; }; - std::vector to_process = {{root.type(), root}}; + std::vector to_process = {{type_tag, root}}; std::unordered_set scanned; while (!to_process.empty()) { Work w = std::move(to_process.back()); diff --git a/torch/csrc/jit/unpickler.h b/torch/csrc/jit/unpickler.h index ab2a11e6cc..0b46ae5ab4 100644 --- a/torch/csrc/jit/unpickler.h +++ b/torch/csrc/jit/unpickler.h @@ -101,5 +101,7 @@ class Unpickler { c10::optional device_; }; +void restoreAccurateTypeTags(const IValue& root, const c10::TypePtr& type_tag); + } // namespace jit } // namespace torch