Skip to content

Commit

Permalink
Use concrete types on call sites for Dict/List (pytorch#22004)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#22004

In future, we want all dicts/lists to store information about the types they contain.
This is only possible if the creation API doesn't allow creating lists/dicts without type information.
This diff removes some call sites that don't specify type information and have it specify type information.

Reviewed By: dzhulgakov

Differential Revision: D15906387

fbshipit-source-id: 64766a2534b52c221e8a5501a85eaad13812e7bd
  • Loading branch information
smessmer authored and facebook-github-bot committed Jul 2, 2019
1 parent 693871d commit 6d58713
Show file tree
Hide file tree
Showing 5 changed files with 4 additions and 7 deletions.
1 change: 0 additions & 1 deletion aten/src/ATen/core/ivalue.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,6 @@ struct CAFFE2_API IValue final {
c10::ArrayRef<at::Tensor> toTensorListRef() const;

//GenericList
IValue(std::vector<IValue> v);
IValue(c10::List<IValue> v);
bool isGenericList() const { return Tag::GenericList == tag; }
c10::List<IValue> toGenericList() &&;
Expand Down
4 changes: 1 addition & 3 deletions aten/src/ATen/core/ivalue_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -593,8 +593,6 @@ inline IValue::IValue(c10::impl::GenericList v)
: tag(Tag::GenericList), is_intrusive_ptr(true) {
payload.as_intrusive_ptr = v.impl_.release();
}
inline IValue::IValue(std::vector<IValue> v)
: IValue(c10::impl::toList(std::move(v))) {}

template<class T> inline IValue::IValue(c10::List<T> v)
: IValue(impl::toGenericList<T>(std::move(v))) {
Expand All @@ -619,7 +617,7 @@ inline IValue::IValue(c10::Dict<Key, Value> v)
: IValue(impl::toGenericDict(std::move(v))) {}

template<class Key, class Value> inline IValue::IValue(std::unordered_map<Key, Value> v)
: IValue(impl::GenericDict()) {
: IValue(Dict<Key, Value>()) {
auto dict = to<c10::Dict<Key, Value>>();
dict.reserve(v.size());
for (auto& e : v) {
Expand Down
2 changes: 1 addition & 1 deletion test/cpp/api/jit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ TEST(TorchScriptTest, TestDictArgMatching) {
def dict_op(a: Dict[str, Tensor], b: str):
return a[b]
)JIT");
c10::impl::GenericDict dict;
c10::Dict<std::string, at::Tensor> dict;
dict.insert("hello", torch::ones({2}));
auto output = module->run_method("dict_op", dict, std::string("hello"));
ASSERT_EQ(1, output.toTensor()[0].item<int64_t>());
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/pybind_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ inline TypedStack toTypedStack(const py::tuple& inputs) {
}

inline IValue createGenericList(py::handle obj, const TypePtr& elem_type) {
c10::List<IValue> elems;
c10::impl::GenericList elems;
for (auto elem : obj) {
elems.push_back(toIValue(elem, elem_type));
}
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/register_prim_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -966,7 +966,7 @@ RegisterOperators reg(
} else {
return [=](Stack& stack) {
const size_t stack_size = stack.size();
c10::List<IValue> vals;
c10::impl::GenericList vals;
vals.reserve(num_inputs);
for (size_t i = stack_size - num_inputs; i < stack_size; ++i) {
vals.emplace_back(std::move(stack[i]));
Expand Down

0 comments on commit 6d58713

Please sign in to comment.