
[PyTorch Edge][type] Add type support for NamedTuple custom class (import) (pytorch#63130)

Summary:
Pull Request resolved: pytorch#63130

Extend `type_parser` to handle the `NamedTuple` type; it can be extended to handle other types when needed. A custom type is serialized in the following format:
```
"qualified_named[
    NamedTuple, [
        [filed_name_1, field_type_1],
        [filed_name_2, field_type_2]
    ]
]"
```
For example:
```
"__torch__.base_models.sparse_nn.pytorch_preproc_types.PreprocOutputType[
    NamedTuple, [
        [float_features, Tensor],
        [id_list_features, List[Tensor]],
        [label, Tensor],
        [weight, Tensor]
    ]
]"
```

For nested types, the type table must be ordered so that a custom type's definition appears before any type that refers to it (the parser resolves `__torch__.C` by name when it appears later in `type_str_2`):
```
std::string type_str_1 = R"(__torch__.C[
    NamedTuple, [
        [field_name_c_1, Tensor],
        [field_name_c_2, Tuple[Tensor, Tensor]]
    ]
])";

std::string type_str_2 = R"(__torch__.B[
    NamedTuple, [
        [field_name_b, __torch__.C]
    ]
])";

std::string type_str_3 = R"(__torch__.A[
    NamedTuple, [
        [field_name_a, __torch__.B]
    ]
])";

std::vector<std::string> type_strs = {type_str_1, type_str_2, type_str_3};
std::vector<c10::TypePtr> type_ptrs = c10::parseType(type_strs);
```

Both `NamedTuple` from `typing` and `namedtuple` from `collections` are supported:
```
from typing import NamedTuple
from collections import namedtuple
```

This change only adds the parser, so the new runtime can now read the above format.
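
As an illustration, a minimal sketch of parsing one such string (the type and field names here are hypothetical; the round trip mirrors the `NamedTuple` parser test below):
```
#include <ATen/core/jit_type.h>

// Hypothetical serialized NamedTuple type, in the format described above.
std::string type_str = R"(__torch__.my_models.Point[
    NamedTuple, [
        [x, Tensor],
        [y, Tensor]
    ]
])";

// Parse the string back into a TypePtr; the annotation string of a
// parsed NamedTuple is its qualified name.
c10::TypePtr tp = c10::parseType(type_str);
// tp->annotation_str() == "__torch__.my_models.Point"
```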
ghstack-source-id: 141293658

Test Plan:
```
buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.CompatiblePrimitiveType'
buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.CompatibleCustomType'

buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.InCompatiblePrimitiveType'
buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.InCompatibleCustomType'
```

Reviewed By: iseeyuan

Differential Revision: D30261547

fbshipit-source-id: 68a9974338464e320b39a5c613dc048f6c5adeb5
cccclai authored and facebook-github-bot committed Oct 22, 2021
1 parent d3fc3c4 commit 5f58764
Showing 6 changed files with 312 additions and 40 deletions.
4 changes: 2 additions & 2 deletions test/cpp/jit/test_lite_interpreter.cpp
@@ -646,7 +646,7 @@ TEST(LiteInterpreterTest, isCompatibleSuccess) {
std::unordered_map<std::string, OperatorInfo> model_ops;
model_ops["aten::add.Scalar"] = OperatorInfo{2};

std::unordered_set<std::string> types = {"List", "int"};
std::unordered_set<std::string> types = {"List", "int", "NamedTuple"};
auto model_info = ModelCompatibilityInfo{
caffe2::serialize::kMaxSupportedBytecodeVersion, model_ops, types};

@@ -688,7 +688,7 @@ TEST(LiteInterpreterTest, isCompatibleFail) {

// test trivial failure due to type
runtime_info = RuntimeCompatibilityInfo::get();
std::unordered_set<std::string> types = {"List", "int", "NamedTuple"};
std::unordered_set<std::string> types = {"List", "int", "Sequence"};

model_info = ModelCompatibilityInfo{
caffe2::serialize::kMaxSupportedBytecodeVersion, model_ops, types};
149 changes: 143 additions & 6 deletions test/cpp/jit/test_mobile_type_parser.cpp
@@ -1,20 +1,17 @@
#include <gtest/gtest.h>
#include <test/cpp/jit/test_utils.h>

#include <ATen/core/jit_type.h>

namespace c10 {
// std::string serializeType(const Type &t);
TypePtr parseType(const std::string& pythonStr);
std::vector<TypePtr> parseType(std::vector<std::string>& pythonStr);
} // namespace c10

namespace torch {
namespace jit {
TEST(MobileTypeParserTest, Empty) {
std::string empty_ps("");
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
ASSERT_ANY_THROW(c10::parseType(empty_ps));
}

// Parse Success cases
TEST(MobileTypeParserTest, RoundTripAnnotationStr) {
std::string int_ps("int");
auto int_tp = c10::parseType(int_ps);
@@ -30,6 +27,20 @@ TEST(MobileTypeParserTest, NestedContainersAnnotationStr) {
ASSERT_EQ(tuple_ps, tuple_tps);
}

TEST(MobileTypeParserTest, TorchBindClass) {
std::string tuple_ps("__torch__.torch.classes.rnn.CellParamsBase");
auto tuple_tp = c10::parseType(tuple_ps);
std::string tuple_tps = tuple_tp->annotation_str();
ASSERT_EQ(tuple_ps, tuple_tps);
}

TEST(MobileTypeParserTest, ListOfTorchBindClass) {
std::string tuple_ps("List[__torch__.torch.classes.rnn.CellParamsBase]");
auto tuple_tp = c10::parseType(tuple_ps);
std::string tuple_tps = tuple_tp->annotation_str();
ASSERT_EQ(tuple_ps, tuple_tps);
}

TEST(MobileTypeParserTest, NestedContainersAnnotationStrWithSpaces) {
std::string tuple_ps(
"Tuple[str, Optional[float], Dict[str, List[Tensor]], int]");
@@ -41,6 +52,110 @@ TEST(MobileTypeParserTest, NestedContainersAnnotationStrWithSpaces) {
ASSERT_EQ(tuple_ps, tuple_space_tps);
}

TEST(MobileTypeParserTest, NamedTuple) {
std::string named_tuple_ps(
"__torch__.base_models.preproc_types.PreprocOutputType["
" NamedTuple, ["
" [float_features, Tensor],"
" [id_list_features, List[Tensor]],"
" [label, Tensor],"
" [weight, Tensor],"
" [prod_prediction, Tuple[Tensor, Tensor]],"
" [id_score_list_features, List[Tensor]],"
" [embedding_features, List[Tensor]],"
" [teacher_label, Tensor]"
" ]"
" ]");

c10::TypePtr named_tuple_tp = c10::parseType(named_tuple_ps);
std::string named_tuple_annotation_str = named_tuple_tp->annotation_str();
ASSERT_EQ(
named_tuple_annotation_str,
"__torch__.base_models.preproc_types.PreprocOutputType");
}

TEST(MobileTypeParserTest, DictNestedNamedTupleTypeList) {
std::string type_str_1(
"__torch__.base_models.preproc_types.PreprocOutputType["
" NamedTuple, ["
" [float_features, Tensor],"
" [id_list_features, List[Tensor]],"
" [label, Tensor],"
" [weight, Tensor],"
" [prod_prediction, Tuple[Tensor, Tensor]],"
" [id_score_list_features, List[Tensor]],"
" [embedding_features, List[Tensor]],"
" [teacher_label, Tensor]"
" ]");
std::string type_str_2(
"Dict[str, __torch__.base_models.preproc_types.PreprocOutputType]");
std::vector<std::string> type_strs = {type_str_1, type_str_2};
std::vector<c10::TypePtr> named_tuple_tps = c10::parseType(type_strs);
std::string named_tuple_annotation_str = named_tuple_tps[1]->annotation_str();
ASSERT_EQ(
named_tuple_annotation_str,
"Dict[str, __torch__.base_models.preproc_types.PreprocOutputType]");
}

TEST(MobileTypeParserTest, NamedTupleNestedNamedTupleTypeList) {
std::string type_str_1(
" __torch__.ccc.xxx ["
" NamedTuple, ["
" [field_name_c_1, Tensor],"
" [field_name_c_2, Tuple[Tensor, Tensor]]"
" ]"
"]");
std::string type_str_2(
"__torch__.bbb.xxx ["
" NamedTuple,["
" [field_name_b, __torch__.ccc.xxx]]"
" ]"
"]");

std::string type_str_3(
"__torch__.aaa.xxx["
" NamedTuple, ["
" [field_name_a, __torch__.bbb.xxx]"
" ]"
"]");

std::vector<std::string> type_strs = {type_str_1, type_str_2, type_str_3};
std::vector<c10::TypePtr> named_tuple_tps = c10::parseType(type_strs);
std::string named_tuple_annotation_str = named_tuple_tps[2]->annotation_str();
ASSERT_EQ(named_tuple_annotation_str, "__torch__.aaa.xxx");
}

TEST(MobileTypeParserTest, NamedTupleNestedNamedTuple) {
std::string named_tuple_ps(
"__torch__.aaa.xxx["
" NamedTuple, ["
" [field_name_a, __torch__.bbb.xxx ["
" NamedTuple, ["
" [field_name_b, __torch__.ccc.xxx ["
" NamedTuple, ["
" [field_name_c_1, Tensor],"
" [field_name_c_2, Tuple[Tensor, Tensor]]"
" ]"
" ]"
" ]"
" ]"
" ]"
" ]"
" ] "
"]");

c10::TypePtr named_tuple_tp = c10::parseType(named_tuple_ps);
std::string named_tuple_annotation_str = named_tuple_tp->str();
ASSERT_EQ(named_tuple_annotation_str, "__torch__.aaa.xxx");
}

// Parse throw cases
TEST(MobileTypeParserTest, Empty) {
std::string empty_ps("");
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
ASSERT_ANY_THROW(c10::parseType(empty_ps));
}

TEST(MobileTypeParserTest, TypoRaises) {
std::string typo_token("List[tensor]");
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
@@ -89,5 +204,27 @@ TEST(MobileTypeParserTest, NonIdentifierRaises) {
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
ASSERT_ANY_THROW(c10::parseType(non_id));
}

TEST(MobileTypeParserTest, DictNestedNamedTupleTypeListRaises) {
std::string type_str_1(
"Dict[str, __torch__.base_models.preproc_types.PreprocOutputType]");
std::string type_str_2(
"__torch__.base_models.preproc_types.PreprocOutputType["
" NamedTuple, ["
" [float_features, Tensor],"
" [id_list_features, List[Tensor]],"
" [label, Tensor],"
" [weight, Tensor],"
" [prod_prediction, Tuple[Tensor, Tensor]],"
" [id_score_list_features, List[Tensor]],"
" [embedding_features, List[Tensor]],"
" [teacher_label, Tensor]"
" ]");
std::vector<std::string> type_strs = {type_str_1, type_str_2};
std::string error_message =
R"(Can't find definition for the type: __torch__.base_models.preproc_types.PreprocOutputType)";
ASSERT_THROWS_WITH_MESSAGE(c10::parseType(type_strs), error_message);
}

} // namespace jit
} // namespace torch
16 changes: 6 additions & 10 deletions torch/csrc/jit/mobile/model_compatibility.cpp
@@ -231,22 +231,18 @@ std::unordered_set<std::string> _get_mobile_model_contained_types(
method_tuple.at(1).toTuple()->elements()[BYTECODE_INDEX_TYPE];
const auto& type_table =
type_table_tuple.toTuple()->elements()[1].toTuple()->elements();

// type_table is a list of IValue, and each IValue is a string,
// for example: "Dict[int, Tuple[Tensor, Tensor, Tensor]]"
std::vector<std::string> type_name_list;
for (const auto& type_definition : type_table) {
std::unordered_set<std::string> type_tokens;
std::string type_name = type_definition.toString()->string();

// parse the type only if it's new, and insert it in the record
if (parsed_type_names_records.find(type_name) ==
parsed_type_names_records.end()) {
parsed_type_names_records.insert(type_name);
at::TypeParser parser(type_name);
parser.parse();
type_tokens = parser.getContainedTypes();
contained_types.insert(type_tokens.begin(), type_tokens.end());
}
type_name_list.emplace_back(type_name);
}
at::TypeParser parser(type_name_list);
parser.parseList();
contained_types = parser.getContainedTypes();
}

return contained_types;
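
For context, the batched flow introduced in this file amounts to the following sketch (the two type strings are hypothetical, taken from the comment in the diff; `at::TypeParser`, `parseList()`, and `getContainedTypes()` are the calls used above):
```
// Collect every type string from the bytecode's type table first ...
std::vector<std::string> type_name_list = {
    "Dict[int, Tuple[Tensor, Tensor, Tensor]]",  // hypothetical entries
    "List[str]"};

// ... then parse them in one pass so that a custom type defined by an
// earlier entry can be referenced by a later one.
at::TypeParser parser(type_name_list);
parser.parseList();
std::unordered_set<std::string> contained_types = parser.getContainedTypes();
// contained_types now holds the type tokens, e.g.
// "Dict", "int", "Tuple", "Tensor", "List", "str".
```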
24 changes: 9 additions & 15 deletions torch/csrc/jit/mobile/parse_bytecode.cpp
@@ -124,24 +124,18 @@ void parseConstants(
function->append_constant(constant);
}
}

void parseTypes(
const c10::ivalue::TupleElements& types_list,
mobile::Function* function) {
static const c10::QualifiedName classPrefix = "__torch__.torch.classes";
for (const auto& t : types_list) {
c10::QualifiedName qn(t.toStringRef());
if (classPrefix.isPrefixOf(qn)) {
auto classType = getCustomClass(qn.qualifiedName());
TORCH_CHECK(
classType,
"The implementation of class ",
qn.qualifiedName(),
" cannot be found.");
function->append_type(classType);
} else {
function->append_type(c10::parseType(t.toStringRef()));
}
std::vector<std::string> types_string_list;
types_string_list.resize(types_list.size());
for (size_t i = 0; i < types_list.size(); i++) {
types_string_list[i] = types_list[i].toString()->string();
}

std::vector<c10::TypePtr> types_ptr_list = c10::parseType(types_string_list);
for (auto& type_ptr : types_ptr_list) {
function->append_type(type_ptr);
}
}

