Skip to content

Commit

Permalink
Use static type information to restore type tags (#25447)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch/pytorch#25447

When we unpickle IValues, we lose type information for List[T]
and Dict[K, V]. We can restore this information using the static
type information contained in the top-level Module/Class type.

This ensures that even after serialization we can always get the
dynamic type of an ivalue using its type() method.

Test Plan: Imported from OSS

Differential Revision: D17127872

Pulled By: zdevito

fbshipit-source-id: 1ffb5e37a7c35c71ac9d3fb7b2edbc7ce3fbec72
  • Loading branch information
zdevito authored and facebook-github-bot committed Sep 18, 2019
1 parent ad0af11 commit 12762cd
Show file tree
Hide file tree
Showing 8 changed files with 230 additions and 5 deletions.
12 changes: 12 additions & 0 deletions aten/src/ATen/core/Dict.h
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,18 @@ class Dict final {
// instead of optional<TypePtr> once types are mandatory.
TypePtr keyType() const;
TypePtr valueType() const;

// [unsafe set type]
// These functions mutate the tagged type of this dictionary in place.
// There is no checking that the members of the dictionary are instances
// of the new types, nor is there a check that other IValues which
// hold references to this dictionary have the right static type.
// This functionality is used only in the unpickler, where at
// creation time the real type of the dictionary is unknown, but
// then later recovered from the static type information of the
// unpickled object.
void unsafeSetKeyType(TypePtr t);
void unsafeSetValueType(TypePtr t);
};

namespace impl {
Expand Down
9 changes: 9 additions & 0 deletions aten/src/ATen/core/Dict_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,5 +191,14 @@ template<class Key, class Value>
TypePtr Dict<Key, Value>::valueType() const {
return impl_->elementTypes.valueType;
}
template <class Key, class Value>
void Dict<Key, Value>::unsafeSetKeyType(TypePtr t) {
  // See [unsafe set type] in Dict.h: overwrites the key type tag in place
  // with no validation of existing entries or aliasing IValues.
  auto& tags = impl_->elementTypes;
  tags.keyType = std::move(t);
}

template <class Key, class Value>
void Dict<Key, Value>::unsafeSetValueType(TypePtr t) {
  // See [unsafe set type] in Dict.h: overwrites the value type tag in place
  // with no validation of existing entries or aliasing IValues.
  auto& tags = impl_->elementTypes;
  tags.valueType = std::move(t);
}

}
3 changes: 3 additions & 0 deletions aten/src/ATen/core/List.h
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,9 @@ class List final {

TypePtr elementType() const;

// See [unsafe set type] for why this exists.
void unsafeSetElementType(TypePtr t);

private:
explicit List(c10::intrusive_ptr<detail::ListImpl<StorageT>>&& elements);
friend struct IValue;
Expand Down
6 changes: 5 additions & 1 deletion aten/src/ATen/core/List_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,13 @@ size_t List<T>::use_count() const {
return impl_.use_count();
}

template<class T>
template <class T>
TypePtr List<T>::elementType() const {
  // Returns the runtime type tag recorded for this list's elements.
  const auto& tag = impl_->elementType;
  return tag;
}

template <class T>
void List<T>::unsafeSetElementType(TypePtr t) {
  // See [unsafe set type] in Dict.h: overwrites the element type tag in
  // place with no validation of existing elements or aliasing IValues.
  auto& tag = impl_->elementType;
  tag = std::move(t);
}
}
3 changes: 3 additions & 0 deletions test/jit_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ def extract_files(buffer):
for a, b in zip(code_files, code_files_2):
self.assertMultiLineEqual(a, b)

if isinstance(m, torch._C.ScriptModule):
self.assertTrue(torch._C._ivalue_tags_match(m, imported._c))


def emitFunctionHook(self, func):
# func has invalid names for export, skip the jitter check
Expand Down
124 changes: 124 additions & 0 deletions torch/csrc/jit/pickler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,12 +524,136 @@ void Pickler::pushTuple(const IValue& ivalue) {
}
}

// Pickled objects are stored in a form compatible with Python pickling.
// In torchscript List[T]/Dict[K, V] are statically typed and contain
// dynamic type tags that allow T, K, and V to be recovered. But this info
// is not stored in the Python pickling information. However, we
// can recover this information from the static type of the top-level
// object being unpickled, because we have a record of the type of the
// objects it contains as attributes.
// `IfPossible` - we can only do this recovery when we have an object as
// the top-level unpickled thing (which is guaranteed for Modules, but
// not for torch.load/torch.save). Otherwise we do not know the types
// of the contained objects and cannot restore the tags.
static void restoreAccurateTypeTagsIfPossible(const IValue& root) {
  // Without a top-level object there is no static type info to propagate.
  if (!root.isObject()) {
    return;
  }
  // A pending item: the static type we expect at a position, and the
  // dynamic value actually found there.
  struct Work {
    TypePtr static_type;
    IValue value;
  };
  std::vector<Work> to_process = {{root.type(), root}};
  std::unordered_set<const void*> scanned;
  while (!to_process.empty()) {
    Work w = std::move(to_process.back());
    to_process.pop_back();
    // ensure we only scan each pointer value once, otherwise this
    // can become exponential (and if we allow recursive data in the future,
    // it would not terminate).
    if (w.value.isPtrType()) {
      const void* key = w.value.internalToPointer();
      auto it = scanned.find(key);
      if (it != scanned.end()) {
        continue;
      }
      scanned.emplace_hint(it, key);
    }
    switch (w.static_type->kind()) {
      case TensorType::Kind:
      case NumberType::Kind:
      case FloatType::Kind:
      case IntType::Kind:
      case NoneType::Kind:
      case GeneratorType::Kind:
      case BoolType::Kind:
      case VarType::Kind:
      case CapsuleType::Kind:
      case StringType::Kind:
      case FunctionType::Kind:
      case DeviceObjType::Kind:
        // no op, there is nothing to tag
        break;
      case AnyType::Kind:
        // if Any type does show up, we no longer have a way to precisely
        // recover the type information since the w.value may be an untagged
        // List/Dict. We should prevent objects being serialized from having
        // the Any type and if we do allow it in functions limit it to
        // non-heap locations.
        TORCH_INTERNAL_ASSERT(
            false, "AnyType should not show up in the static type of objects");
      case TupleType::Kind: {
        auto t = w.value.toTuple();
        auto ttype = w.static_type->expect<TupleType>();
        // Pair up each tuple element with its declared static type.
        for (size_t i = 0; i < ttype->containedTypes().size(); ++i) {
          Work elem = {ttype->containedTypes().at(i), t->elements().at(i)};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case FutureType::Kind: {
        auto f = w.value.toFuture();
        auto t = w.static_type->expect<FutureType>();
        // Only a completed future has a value to tag.
        if (f->completed()) {
          Work elem = {t->getElementType(), f->value()};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case OptionalType::Kind: {
        // A present optional is tagged as its element type; None needs no tag.
        if (!w.value.isNone()) {
          auto t = w.static_type->expect<OptionalType>();
          Work elem = {t->getElementType(), w.value};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case ListType::Kind: {
        // specialized lists do not need their type refined, so we can exit
        // early here
        if (!w.value.isGenericList()) {
          break;
        }
        auto elem_type = w.static_type->cast<ListType>()->getElementType();
        auto lst = w.value.toGenericList();
        lst.unsafeSetElementType(elem_type);
        for (const IValue& item : lst) {
          Work elem = {elem_type, item};
          to_process.emplace_back(std::move(elem));
        }
      } break;
      case DictType::Kind: {
        auto dt = w.static_type->cast<DictType>();
        auto d = w.value.toGenericDict();
        d.unsafeSetKeyType(dt->getKeyType());
        d.unsafeSetValueType(dt->getValueType());
        for (const auto& item : d) {
          Work kelem = {dt->getKeyType(), item.key()};
          Work velem = {dt->getValueType(), item.value()};
          to_process.emplace_back(std::move(kelem));
          to_process.emplace_back(std::move(velem));
        }
      } break;
      // in both cases the dynamic type is a class, and we are going to tag
      // with the dynamic type
      case InterfaceType::Kind:
      case ClassType::Kind: {
        auto obj = w.value.toObject();
        auto typ = obj->type(); // note: intentionally using the dynamic type,
                                // the static type is potentially less accurate
        for (size_t i = 0; i < typ->numAttributes(); ++i) {
          Work elem = {typ->getAttribute(i), obj->getSlot(i)};
          to_process.emplace_back(std::move(elem));
        }
      } break; // was a stray ';' after the case block; 'break' added for
               // consistency with the other cases (last case, so no
               // behavior change)
    }
  }
}

// Runs the unpickling program and returns the single resulting IValue.
// Fails if the program leaves anything other than exactly one value on
// the stack.
IValue Unpickler::parse_ivalue() {
run();
TORCH_CHECK(
stack_.size() == 1,
"Unpickler expected 1 element on the stack, but found ",
stack_.size());
// Pickling drops List[T]/Dict[K, V] type tags; recover them from the
// static type information when the result is an object (see
// restoreAccurateTypeTagsIfPossible).
restoreAccurateTypeTagsIfPossible(stack_[0]);

return stack_[0];
}
Expand Down
10 changes: 6 additions & 4 deletions torch/csrc/jit/register_prim_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1206,11 +1206,13 @@ RegisterOperators reg(
throw std::runtime_error(
"DictConstruct must have an even number of inputs");
}
TORCH_INTERNAL_ASSERT(node->outputs().size() == 1, "DictConstruct must have exactly one output");
TORCH_INTERNAL_ASSERT(
node->outputs().size() == 1,
"DictConstruct must have exactly one output");
TypePtr output_type = node->outputs()[0]->type();
TORCH_INTERNAL_ASSERT(output_type->kind() == TypeKind::DictType, "DictConstruct output must be of Dict type.");
TypePtr key_type = static_cast<const DictType*>(output_type.get())->getKeyType();
TypePtr value_type = static_cast<const DictType*>(output_type.get())->getValueType();
auto dt = output_type->expect<DictType>();
TypePtr key_type = dt->getKeyType();
TypePtr value_type = dt->getValueType();
return [=](Stack& stack) {
auto vals = c10::impl::GenericDict(key_type, value_type);
for (size_t i = 0; i < num_inputs; i += 2) {
Expand Down
68 changes: 68 additions & 0 deletions torch/csrc/jit/script/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,73 @@ void addFunctionToModule(Module& module, const StrongFunctionPtr& func) {
module.type()->addMethod(method);
}

// this is used in our test suite to check that we correctly preserved type
// tags across serialization: it walks both module object graphs in lockstep
// and verifies that corresponding IValues have equal (unshaped) types.
bool ivalue_tags_match(const Module& lhs, const Module& rhs) {
  // A pair of corresponding values, one from each module.
  struct Work {
    IValue a;
    IValue b;
  };
  std::unordered_set<const void*> visited;
  std::vector<Work> work = {{lhs.module_object(), rhs.module_object()}};
  while (!work.empty()) {
    Work item = std::move(work.back()); // move: IValue copies bump refcounts
    work.pop_back();
    if (item.a.isPtrType()) {
      // uncomment to debug type matching errors
      // std::cout << "MATCHING " << /*item.a <<*/ "(" << *item.a.type()
      //           << ") " << item.a.internalToPointer() << " "
      //           << /*item.b <<*/ " (" << *item.b.type() << ") "
      //           << item.b.internalToPointer() << "\n";

      // visit each heap value only once so shared/aliased subobjects do not
      // blow up the traversal
      if (visited.count(item.a.internalToPointer())) {
        continue;
      }
      visited.emplace(item.a.internalToPointer());
    }
    if (*unshapedType(item.a.type()) != *unshapedType(item.b.type())) {
      return false;
    }
    // check tags for objects that contain subobjects
    if (item.a.isObject()) {
      auto ao = item.a.toObject();
      auto bo = item.b.toObject();
      for (size_t i = 0; i < ao->slots().size(); ++i) {
        work.emplace_back(Work{ao->slots().at(i), bo->slots().at(i)});
      }
    } else if (item.a.isTuple()) {
      auto at = item.a.toTuple();
      auto bt = item.b.toTuple();
      for (size_t i = 0; i < at->elements().size(); ++i) {
        work.emplace_back(Work{at->elements().at(i), bt->elements().at(i)});
      }
    } else if (item.a.isGenericList()) {
      auto al = item.a.toGenericList();
      auto bl = item.b.toGenericList();
      for (size_t i = 0; i < al.size(); ++i) {
        work.emplace_back(Work{al.get(i), bl.get(i)});
      }
    } else if (item.a.isGenericDict()) {
      auto ad = item.a.toGenericDict();
      auto bd = item.b.toGenericDict();
      // renamed from 'item' to avoid shadowing the outer Work item
      for (auto& entry : ad) {
        // Dictionary keys cannot contain List/Dicts that require tags
        // so we do not have to check them.
        // Furthermore without ordered dicts it is expensive to find the
        // equivalent key
        work.emplace_back(Work{entry.value(), bd.at(entry.key())});
      }
    } else if (item.a.isFuture()) {
      auto af = item.a.toFuture();
      auto bf = item.b.toFuture();
      af->wait();
      bf->wait();
      work.emplace_back(Work{af->value(), bf->value()});
    }
  }

  return true;
}

void initJitScriptBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();

Expand Down Expand Up @@ -872,6 +939,7 @@ void initJitScriptBindings(PyObject* module) {
auto fn = cu->create_function(std::move(name), graph);
return StrongFunctionPtr(std::move(cu), fn);
});
m.def("_ivalue_tags_match", ivalue_tags_match);

py::class_<testing::FileCheck>(m, "FileCheck")
.def(py::init<>())
Expand Down

0 comments on commit 12762cd

Please sign in to comment.