Skip to content

Commit

Permalink
Remove attempToRecoverType (#26767)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch/pytorch#26767

Now that we have tagged ivalues, we can accurately recover the type with
`ivalue.type()`. This reomoves the other half-implemented pathways that
were created because we didn't have tags.

Test Plan: Imported from OSS

Differential Revision: D17561191

Pulled By: zdevito

fbshipit-source-id: 26aaa134099e75659a230d8a5a34a86dc39a3c5c
  • Loading branch information
zdevito authored and facebook-github-bot committed Oct 16, 2019
1 parent fb45171 commit 5136ed0
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 148 deletions.
6 changes: 3 additions & 3 deletions android/pytorch_android/src/main/cpp/pytorch_jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {

auto jivalue_first_element = jarray->getElement(0);
auto first_element = JIValue::JIValueToAtIValue(jivalue_first_element);
c10::TypePtr typePtr = c10::attemptToRecoverType(first_element);
c10::TypePtr typePtr = first_element.type();
c10::impl::GenericList list{typePtr};
list.reserve(n);
list.push_back(first_element);
Expand All @@ -529,7 +529,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
}

auto firstEntryValue = JIValue::JIValueToAtIValue(it->second);
c10::TypePtr typePtr = c10::attemptToRecoverType(firstEntryValue);
c10::TypePtr typePtr = firstEntryValue.type();
c10::impl::GenericDict dict{c10::StringType::get(), typePtr};
dict.insert(it->first->toStdString(), firstEntryValue);
it++;
Expand All @@ -552,7 +552,7 @@ class JIValue : public facebook::jni::JavaClass<JIValue> {
}

auto firstEntryValue = JIValue::JIValueToAtIValue(it->second);
c10::TypePtr typePtr = c10::attemptToRecoverType(firstEntryValue);
c10::TypePtr typePtr = firstEntryValue.type();
c10::impl::GenericDict dict{c10::IntType::get(), typePtr};
dict.insert(it->first->longValue(), firstEntryValue);
it++;
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/core/function_schema_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,12 @@ inline void FunctionSchema::checkArg(
const IValue& value,
const Argument& argument,
optional<size_t> pos) const {
if (!isSubvalueOf(value, argument.type())) {
if (!value.type()->isSubtypeOf(argument.type())) {
std::string position = pos ? ::c10::str(" in position ", *pos) : "";
TORCH_CHECK(
false,
formatTypeMismatchMsg(
argument, attemptToRecoverType(value)->python_str(), pos));
argument, value.type()->python_str(), pos));
}
}

Expand Down
4 changes: 0 additions & 4 deletions aten/src/ATen/core/jit_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -1317,10 +1317,6 @@ inline TypePtr getTypePtr() {
return detail::getTypePtr_<T>::call();
}

CAFFE2_API TypePtr incompleteInferTypeFrom(const IValue& value);
CAFFE2_API TypePtr attemptToRecoverType(const IValue& input_ivalue);
CAFFE2_API bool isSubvalueOf(const IValue& input_ivalue, TypePtr type);

using TypeEnv = std::unordered_map<std::string, TypePtr>;
struct MatchTypeReturn {
MatchTypeReturn(std::string reason) : reason_(std::move(reason)) {}
Expand Down
111 changes: 0 additions & 111 deletions aten/src/ATen/core/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,117 +137,6 @@ ListTypePtr ListType::ofBools() {
return value;
}

// why incomplete? You cannot completely recover a type from
// an IValue, List[List[int]] and List[List[Tensor]] will both
// become ivalue.isGenericList() and cannot be recovered.
// The only appropriate place to use this is where you know that
// you are only dealing with a subset of objects where you can recover
// the type, like in the tracer.
TypePtr incompleteInferTypeFrom(const IValue& value) {
if (value.isTensor()) {
return TensorType::create(value.toTensor());
} else if (value.isDouble()) {
return FloatType::get();
} else if (value.isInt()) {
return IntType::get();
} else if (value.isBool()) {
return BoolType::get();
} else if (value.isString()) {
return StringType::get();
} else if (value.isIntList()) {
return ListType::ofInts();
} else if (value.isTensorList()) {
return ListType::ofTensors();
} else if (value.isBoolList()) {
return ListType::ofBools();
} else if (value.isDoubleList()) {
return ListType::ofFloats();
} else if (value.isTuple()) {
return TupleType::create(fmap(value.toTuple()->elements(), incompleteInferTypeFrom));
} else if (value.isDevice()) {
return DeviceObjType::get();
} else if (value.isObject()) {
return value.toObject()->type();
}
AT_ERROR("Type cannot be accurately recovered from this IValue.");
}

// This attempts to recover the type from an IValue, including nested Generic
// Lists. It only examines the first element (the first of the iterator in the
// case of a dict) of each generic container,
// and if a generic container is empty returns typevar as the base element.
// XXX: only used for better error messages, should not be used elsewhere
TypePtr attemptToRecoverType(const IValue& ivalue) {
if (ivalue.isGenericList()) {
auto ivalue_list = ivalue.toGenericListRef();
if (ivalue_list.size() == 0) {
return ListType::create(VarType::create("t"));
}
return ListType::create(attemptToRecoverType(ivalue_list[0]));
}
if (ivalue.isGenericDict()) {
auto dict = ivalue.toGenericDict();
if (dict.size() == 0) {
return DictType::create(VarType::create("t"), VarType::create("t"));
}
auto item = dict.begin();
return DictType::create(
attemptToRecoverType(item->key()), attemptToRecoverType(item->value()));
}
return incompleteInferTypeFrom(ivalue);
}

// Checks if input_ivalue is a subvalue of type.
bool isSubvalueOf(const IValue& ivalue, TypePtr type) {
if (auto optional = type->cast<OptionalType>()) {
// Unwrap the optional if the ivalue is not none
if (ivalue.isNone()) {
return true;
} else {
return isSubvalueOf(ivalue, optional->getElementType());
}
}

if (ivalue.isTuple()) {
auto elems = ivalue.toTuple()->elements();
auto tuple_type = type->cast<TupleType>();
if (!tuple_type || tuple_type->elements().size() != elems.size()) {
return false;
}
auto type_elem = tuple_type->elements();
bool is_subvalue = true;
for (size_t i = 0; i < type_elem.size() && is_subvalue; ++i) {
is_subvalue = isSubvalueOf(elems[i], type_elem[i]);
}
return is_subvalue;
}
if (ivalue.isGenericList()) {
auto list_type = type->cast<ListType>();
if (!list_type) {
return false;
}
auto ivalue_list = ivalue.toGenericListRef();
auto element_type = list_type->getElementType();
return std::all_of(ivalue_list.begin(), ivalue_list.end(), [&](const IValue& list_elem) {
return isSubvalueOf(list_elem, element_type);
});
}
if (ivalue.isGenericDict()) {
auto dict_type = type->expect<DictType>();
const auto dict = ivalue.toGenericDict();
return std::all_of(
dict.begin(), dict.end(), [=](const c10::impl::GenericDict::iterator::value_type& item) {
return isSubvalueOf(item.key(), dict_type->getKeyType()) &&
isSubvalueOf(item.value(), dict_type->getValueType());
});
}
if (ivalue.isObject()) {
return ivalue.toObjectRef().type()->isSubtypeOf(type);
}

return incompleteInferTypeFrom(ivalue)->isSubtypeOf(type);
}

c10::optional<TypePtr> tryEitherIsTheSuperType(const TypePtr& t1, const TypePtr& t2) {
if (t1->isSubtypeOf(t2)) {
return t2;
Expand Down
2 changes: 1 addition & 1 deletion test/cpp/api/jit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ TEST(TorchScriptTest, TestNestedIValueModuleArgMatching) {
std::string(error.what_without_backtrace())
.find("nested_loop() Expected a value of type 'List[List[Tensor]]'"
" for argument 'a' but instead found type "
"'List[List[List[t]]]'") == 0);
"'List[List[List[Tensor]]]'") == 0);
};
}

Expand Down
11 changes: 10 additions & 1 deletion torch/csrc/jit/import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,16 @@ IValue ScriptModuleDeserializer::readArchive(const std::string& archive_name) {
// XXX: Do not optimize __setstate__, so that we don't try to
// specialize the class before it is initialized.
setGraphExecutorOptimize(false);
(*type.type_->getMethod("__setstate__"))({obj, input});
Function* set_state = type.type_->getMethod("__setstate__");
// since we are in the middle of unpickling we might still have lists and
// dicts that do not have accurate tags (e.g. they report they are
// List[Any]). But we need to run __setstate__ which will check the input
// type and may access the tags. Since setstate has a known input type, we
// can correctly restore the tags now by apply the input type of set_state
// to the state object being passed.
restoreAccurateTypeTags(
input, set_state->getSchema().arguments().at(1).type());
(*set_state)({obj, input});
setGraphExecutorOptimize(true);
postSetStateValidate(obj);
return obj;
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ void initJITBindings(PyObject* module) {
auto g_inputs = graph->inputs();
for (size_t i = 0; i < inputs.size(); ++i) {
if (stack[i].isTensor()) {
g_inputs[i]->setType(incompleteInferTypeFrom(stack[i]));
g_inputs[i]->setType(stack[i].type());
}
}
PropagateInputShapes(graph);
Expand Down
30 changes: 5 additions & 25 deletions torch/csrc/jit/unpickler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,26 +32,9 @@ PicklerClass getClass(const std::string& str) {
AT_ERROR("Unknown class name for unpickler: ", str);
}

static void postSetStateValidate(const IValue& v) {
auto obj = v.toObject();
const auto& objType = obj->type();
for (size_t i = 0; i < objType->numAttributes(); i++) {
const auto& attrType = objType->getAttribute(i);
const auto& attrName = objType->getAttributeName(i);
const auto& slot = obj->getSlot(i);
// const auto attrType = objType->getAttribute(i);
// Verify that all the non-optional attributes have been initialized
// TODO: Issue #20497
if (attrType->kind() != TypeKind::OptionalType) {
TORCH_CHECK(
!slot.isNone(),
"The field '",
attrName,
"' was left unitialized after __setstate__, but expected a ",
"value of type '",
attrType->python_str(),
"'");
}
static void restoreAccurateTypeTagsIfPossible(const IValue& root) {
if (root.isObject()) {
restoreAccurateTypeTags(root, root.type());
}
}

Expand All @@ -66,15 +49,12 @@ static void postSetStateValidate(const IValue& v) {
// the top-level unpickled thing (which is guarenteed for Modules, but
// not for torch.load/torch,save). Otherwise we do not know the types
// of the contained objects and cannot restore the tags.
static void restoreAccurateTypeTagsIfPossible(const IValue& root) {
if (!root.isObject()) {
return;
}
void restoreAccurateTypeTags(const IValue& root, const TypePtr& type_tag) {
struct Work {
TypePtr static_type;
IValue value;
};
std::vector<Work> to_process = {{root.type(), root}};
std::vector<Work> to_process = {{type_tag, root}};
std::unordered_set<const void*> scanned;
while (!to_process.empty()) {
Work w = std::move(to_process.back());
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/jit/unpickler.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,5 +101,7 @@ class Unpickler {
c10::optional<at::Device> device_;
};

void restoreAccurateTypeTags(const IValue& root, const c10::TypePtr& type_tag);

} // namespace jit
} // namespace torch

0 comments on commit 5136ed0

Please sign in to comment.