From 5f58764d1db41189a22bf94788e03889c4c3559a Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Fri, 22 Oct 2021 00:39:29 -0700 Subject: [PATCH] [PyTorch Edge][type] Add type support for NamedTuple custom class (import) (#63130) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63130 Extend `type_parser` to handle the `NamedTuple` type. It can be extended to handle other types when needed. The custom type will follow the following format: ``` "qualified_name[ NamedTuple, [ [field_name_1, field_type_1], [field_name_2, field_type_2] ] ]" ``` For example: ``` "__torch__.base_models.sparse_nn.pytorch_preproc_types.PreprocOutputType[ NamedTuple, [ [float_features, Tensor], [id_list_features, List[Tensor]], [label, Tensor], [weight, Tensor], ] ]" ``` For nested types, the order of type lists from the type table should be: ``` std::string type_str_1 = “__torch__.C [ NamedTuple, [ [field_name_c_1, Tensor], [field_name_c_2, Tuple[Tensor, Tensor]], ] ]” std::string type_str_2 = “__torch__.B [ NamedTuple, [ [field_name_b, __torch__.C ] ] ]” std::string type_str_3 = “__torch__.A[ NamedTuple, [ [field_name_a, __torch__.B] ] ]” std::vector type_strs = {type_str_1, type_str_2, type_str_3}; std::vector type_ptrs = c10::parseType(type_strs); ``` namedtuple from both `collections` and `typing` is supported ``` from typing import NamedTuple from collections import namedtuple ``` This change only adds the parser, so now the new runtime can read the above format. 
ghstack-source-id: 141293658 Test Plan: ``` buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.CompatiblePrimitiveType' buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.CompatibleCustomType' buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.InCompatiblePrimitiveType' buck test mode/dev //caffe2/test/cpp/jit:jit -- --exact 'caffe2/test/cpp/jit:jit - LiteInterpreterTest.InCompatibleCustomType' ``` Reviewed By: iseeyuan Differential Revision: D30261547 fbshipit-source-id: 68a9974338464e320b39a5c613dc048f6c5adeb5 --- test/cpp/jit/test_lite_interpreter.cpp | 4 +- test/cpp/jit/test_mobile_type_parser.cpp | 149 +++++++++++++++++- torch/csrc/jit/mobile/model_compatibility.cpp | 16 +- torch/csrc/jit/mobile/parse_bytecode.cpp | 24 ++- torch/csrc/jit/mobile/type_parser.cpp | 149 +++++++++++++++++- torch/csrc/jit/mobile/type_parser.h | 10 ++ 6 files changed, 312 insertions(+), 40 deletions(-) diff --git a/test/cpp/jit/test_lite_interpreter.cpp b/test/cpp/jit/test_lite_interpreter.cpp index f29396e69cd5fe..82d3cba61ba354 100644 --- a/test/cpp/jit/test_lite_interpreter.cpp +++ b/test/cpp/jit/test_lite_interpreter.cpp @@ -646,7 +646,7 @@ TEST(LiteInterpreterTest, isCompatibleSuccess) { std::unordered_map model_ops; model_ops["aten::add.Scalar"] = OperatorInfo{2}; - std::unordered_set types = {"List", "int"}; + std::unordered_set types = {"List", "int", "NamedTuple"}; auto model_info = ModelCompatibilityInfo{ caffe2::serialize::kMaxSupportedBytecodeVersion, model_ops, types}; @@ -688,7 +688,7 @@ TEST(LiteInterpreterTest, isCompatibleFail) { // test trivial failure due to type runtime_info = RuntimeCompatibilityInfo::get(); - std::unordered_set types = {"List", "int", "NamedTuple"}; + std::unordered_set types = {"List", "int", "Sequence"}; model_info = ModelCompatibilityInfo{ caffe2::serialize::kMaxSupportedBytecodeVersion, 
model_ops, types}; diff --git a/test/cpp/jit/test_mobile_type_parser.cpp b/test/cpp/jit/test_mobile_type_parser.cpp index ea051dac8ce28b..9835b0928e2dc6 100644 --- a/test/cpp/jit/test_mobile_type_parser.cpp +++ b/test/cpp/jit/test_mobile_type_parser.cpp @@ -1,20 +1,17 @@ #include +#include #include namespace c10 { -// std::string serializeType(const Type &t); TypePtr parseType(const std::string& pythonStr); +std::vector parseType(std::vector& pythonStr); } // namespace c10 namespace torch { namespace jit { -TEST(MobileTypeParserTest, Empty) { - std::string empty_ps(""); - // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) - ASSERT_ANY_THROW(c10::parseType(empty_ps)); -} +// Parse Success cases TEST(MobileTypeParserTest, RoundTripAnnotationStr) { std::string int_ps("int"); auto int_tp = c10::parseType(int_ps); @@ -30,6 +27,20 @@ TEST(MobileTypeParserTest, NestedContainersAnnotationStr) { ASSERT_EQ(tuple_ps, tuple_tps); } +TEST(MobileTypeParserTest, TorchBindClass) { + std::string tuple_ps("__torch__.torch.classes.rnn.CellParamsBase"); + auto tuple_tp = c10::parseType(tuple_ps); + std::string tuple_tps = tuple_tp->annotation_str(); + ASSERT_EQ(tuple_ps, tuple_tps); +} + +TEST(MobileTypeParserTest, ListOfTorchBindClass) { + std::string tuple_ps("List[__torch__.torch.classes.rnn.CellParamsBase]"); + auto tuple_tp = c10::parseType(tuple_ps); + std::string tuple_tps = tuple_tp->annotation_str(); + ASSERT_EQ(tuple_ps, tuple_tps); +} + TEST(MobileTypeParserTest, NestedContainersAnnotationStrWithSpaces) { std::string tuple_ps( "Tuple[str, Optional[float], Dict[str, List[Tensor]], int]"); @@ -41,6 +52,110 @@ TEST(MobileTypeParserTest, NestedContainersAnnotationStrWithSpaces) { ASSERT_EQ(tuple_ps, tuple_space_tps); } +TEST(MobileTypeParserTest, NamedTuple) { + std::string named_tuple_ps( + "__torch__.base_models.preproc_types.PreprocOutputType[" + " NamedTuple, [" + " [float_features, Tensor]," + " [id_list_features, List[Tensor]]," + " [label, Tensor]," + " 
[weight, Tensor]," + " [prod_prediction, Tuple[Tensor, Tensor]]," + " [id_score_list_features, List[Tensor]]," + " [embedding_features, List[Tensor]]," + " [teacher_label, Tensor]" + " ]" + " ]"); + + c10::TypePtr named_tuple_tp = c10::parseType(named_tuple_ps); + std::string named_tuple_annotation_str = named_tuple_tp->annotation_str(); + ASSERT_EQ( + named_tuple_annotation_str, + "__torch__.base_models.preproc_types.PreprocOutputType"); +} + +TEST(MobileTypeParserTest, DictNestedNamedTupleTypeList) { + std::string type_str_1( + "__torch__.base_models.preproc_types.PreprocOutputType[" + " NamedTuple, [" + " [float_features, Tensor]," + " [id_list_features, List[Tensor]]," + " [label, Tensor]," + " [weight, Tensor]," + " [prod_prediction, Tuple[Tensor, Tensor]]," + " [id_score_list_features, List[Tensor]]," + " [embedding_features, List[Tensor]]," + " [teacher_label, Tensor]" + " ]"); + std::string type_str_2( + "Dict[str, __torch__.base_models.preproc_types.PreprocOutputType]"); + std::vector type_strs = {type_str_1, type_str_2}; + std::vector named_tuple_tps = c10::parseType(type_strs); + std::string named_tuple_annotation_str = named_tuple_tps[1]->annotation_str(); + ASSERT_EQ( + named_tuple_annotation_str, + "Dict[str, __torch__.base_models.preproc_types.PreprocOutputType]"); +} + +TEST(MobileTypeParserTest, NamedTupleNestedNamedTupleTypeList) { + std::string type_str_1( + " __torch__.ccc.xxx [" + " NamedTuple, [" + " [field_name_c_1, Tensor]," + " [field_name_c_2, Tuple[Tensor, Tensor]]" + " ]" + "]"); + std::string type_str_2( + "__torch__.bbb.xxx [" + " NamedTuple,[" + " [field_name_b, __torch__.ccc.xxx]]" + " ]" + "]"); + + std::string type_str_3( + "__torch__.aaa.xxx[" + " NamedTuple, [" + " [field_name_a, __torch__.bbb.xxx]" + " ]" + "]"); + + std::vector type_strs = {type_str_1, type_str_2, type_str_3}; + std::vector named_tuple_tps = c10::parseType(type_strs); + std::string named_tuple_annotation_str = named_tuple_tps[2]->annotation_str(); + 
ASSERT_EQ(named_tuple_annotation_str, "__torch__.aaa.xxx"); +} + +TEST(MobileTypeParserTest, NamedTupleNestedNamedTuple) { + std::string named_tuple_ps( + "__torch__.aaa.xxx[" + " NamedTuple, [" + " [field_name_a, __torch__.bbb.xxx [" + " NamedTuple, [" + " [field_name_b, __torch__.ccc.xxx [" + " NamedTuple, [" + " [field_name_c_1, Tensor]," + " [field_name_c_2, Tuple[Tensor, Tensor]]" + " ]" + " ]" + " ]" + " ]" + " ]" + " ]" + " ] " + "]"); + + c10::TypePtr named_tuple_tp = c10::parseType(named_tuple_ps); + std::string named_tuple_annotation_str = named_tuple_tp->str(); + ASSERT_EQ(named_tuple_annotation_str, "__torch__.aaa.xxx"); +} + +// Parse throw cases +TEST(MobileTypeParserTest, Empty) { + std::string empty_ps(""); + // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) + ASSERT_ANY_THROW(c10::parseType(empty_ps)); +} + TEST(MobileTypeParserTest, TypoRaises) { std::string typo_token("List[tensor]"); // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) @@ -89,5 +204,27 @@ TEST(MobileTypeParserTest, NonIdentifierRaises) { // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto) ASSERT_ANY_THROW(c10::parseType(non_id)); } + +TEST(MobileTypeParserTest, DictNestedNamedTupleTypeListRaises) { + std::string type_str_1( + "Dict[str, __torch__.base_models.preproc_types.PreprocOutputType]"); + std::string type_str_2( + "__torch__.base_models.preproc_types.PreprocOutputType[" + " NamedTuple, [" + " [float_features, Tensor]," + " [id_list_features, List[Tensor]]," + " [label, Tensor]," + " [weight, Tensor]," + " [prod_prediction, Tuple[Tensor, Tensor]]," + " [id_score_list_features, List[Tensor]]," + " [embedding_features, List[Tensor]]," + " [teacher_label, Tensor]" + " ]"); + std::vector type_strs = {type_str_1, type_str_2}; + std::string error_message = + R"(Can't find definition for the type: __torch__.base_models.preproc_types.PreprocOutputType)"; + ASSERT_THROWS_WITH_MESSAGE(c10::parseType(type_strs), error_message); +} + } // 
namespace jit } // namespace torch diff --git a/torch/csrc/jit/mobile/model_compatibility.cpp b/torch/csrc/jit/mobile/model_compatibility.cpp index 04ffad6ca4a132..8e14efbf3a3da1 100644 --- a/torch/csrc/jit/mobile/model_compatibility.cpp +++ b/torch/csrc/jit/mobile/model_compatibility.cpp @@ -231,22 +231,18 @@ std::unordered_set _get_mobile_model_contained_types( method_tuple.at(1).toTuple()->elements()[BYTECODE_INDEX_TYPE]; const auto& type_table = type_table_tuple.toTuple()->elements()[1].toTuple()->elements(); + // type_table is a list of IValue, and each IValue is a string, // for example: "Dict[int, Tuple[Tensor, Tensor, Tensor]]" + std::vector type_name_list; for (const auto& type_definition : type_table) { std::unordered_set type_tokens; std::string type_name = type_definition.toString()->string(); - - // parse the type only if it's new, and insert it in the record - if (parsed_type_names_records.find(type_name) == - parsed_type_names_records.end()) { - parsed_type_names_records.insert(type_name); - at::TypeParser parser(type_name); - parser.parse(); - type_tokens = parser.getContainedTypes(); - contained_types.insert(type_tokens.begin(), type_tokens.end()); - } + type_name_list.emplace_back(type_name); } + at::TypeParser parser(type_name_list); + parser.parseList(); + contained_types = parser.getContainedTypes(); } return contained_types; diff --git a/torch/csrc/jit/mobile/parse_bytecode.cpp b/torch/csrc/jit/mobile/parse_bytecode.cpp index 10f416e31dd827..8c5acb583b424c 100644 --- a/torch/csrc/jit/mobile/parse_bytecode.cpp +++ b/torch/csrc/jit/mobile/parse_bytecode.cpp @@ -124,24 +124,18 @@ void parseConstants( function->append_constant(constant); } } - void parseTypes( const c10::ivalue::TupleElements& types_list, mobile::Function* function) { - static const c10::QualifiedName classPrefix = "__torch__.torch.classes"; - for (const auto& t : types_list) { - c10::QualifiedName qn(t.toStringRef()); - if (classPrefix.isPrefixOf(qn)) { - auto classType = 
getCustomClass(qn.qualifiedName()); - TORCH_CHECK( - classType, - "The implementation of class ", - qn.qualifiedName(), - " cannot be found."); - function->append_type(classType); - } else { - function->append_type(c10::parseType(t.toStringRef())); - } + std::vector types_string_list; + types_string_list.resize(types_list.size()); + for (size_t i = 0; i < types_list.size(); i++) { + types_string_list[i] = types_list[i].toString()->string(); + } + + std::vector types_ptr_list = c10::parseType(types_string_list); + for (auto& type_ptr : types_ptr_list) { + function->append_type(type_ptr); } } diff --git a/torch/csrc/jit/mobile/type_parser.cpp b/torch/csrc/jit/mobile/type_parser.cpp index 3210773cb4acd9..ee5787d1216a95 100644 --- a/torch/csrc/jit/mobile/type_parser.cpp +++ b/torch/csrc/jit/mobile/type_parser.cpp @@ -16,12 +16,14 @@ using torch::jit::string_to_type_lut; using torch::jit::valid_single_char_tokens; namespace c10 { + namespace { // Torchbind custom class always starts with the follow prefix, so use it as an // identifier for torchbind custom class type static constexpr const char* kTypeTorchbindCustomClass = "__torch__.torch.classes"; +static constexpr const char* kTypeNamedTuple = "NamedTuple"; bool isSpecialChar(char a) { for (const char* c = valid_single_char_tokens; *c; c++) { @@ -37,6 +39,42 @@ TypeParser::TypeParser(std::string pythonStr) lex(); } +TypeParser::TypeParser(std::vector& pythonStrs) + : start_(0), pythonStrs_(pythonStrs) {} + +// For the Python string list parsing, the order of the Python string matters. +// In bytecode, the order of the type list corresponds to the order of +// instruction. In nested type, the lowest level type will be at the beginning +// of the type list. It is possible to parse it without worrying about +// ordering, but it also introduces 1) extra cost to process nested type to +// the correct order 2) losing the benefit that the instruction order is likely +// problematic if type list parsing fails. 
+std::vector TypeParser::parseList() { + std::vector typePtrs; + typePtrs.resize(pythonStrs_.size()); + static const c10::QualifiedName classPrefix = "__torch__.torch.classes"; + for (size_t i = 0; i < pythonStrs_.size(); i++) { + c10::QualifiedName qn(pythonStrs_[i]); + c10::TypePtr type_ptr; + if (classPrefix.isPrefixOf(qn)) { + type_ptr = torch::getCustomClass(qn.qualifiedName()); + TORCH_CHECK( + type_ptr, + "The implementation of class ", + qn.qualifiedName(), + " cannot be found."); + } else { + pythonStr_ = pythonStrs_[i]; + start_ = 0; + lex(); + type_ptr = parse(); + } + typePtrs[i] = type_ptr; + str_type_ptr_map_[type_ptr->repr_str()] = type_ptr; + } + return typePtrs; +} + // The list of non-simple types supported by currrent parser. std::unordered_set TypeParser::getNonSimpleType() { static std::unordered_set nonSimpleTypes{ @@ -47,7 +85,7 @@ std::unordered_set TypeParser::getNonSimpleType() { // The list of custom types supported by currrent parser. std::unordered_set TypeParser::getCustomType() { static std::unordered_set customeTypes{ - kTypeTorchbindCustomClass}; + kTypeTorchbindCustomClass, kTypeNamedTuple}; return customeTypes; } @@ -102,7 +140,14 @@ TypePtr TypeParser::parse() { contained_types_.insert(token); return parseNonSimple(token); } else if (token == "__torch__") { - return parseTorchbindClassType(); + expectChar('.'); + if (cur() == "torch") { + // torch bind class starts with __torch__.torch.classes + return parseTorchbindClassType(); + } else { + // other class starts with __torch__ following by custom names + return parseCustomType(); + } } else { TORCH_CHECK( false, @@ -114,13 +159,98 @@ TypePtr TypeParser::parse() { return nullptr; } +// NamedTuple custom type will be following structure: +// "qualified_named[ +// NamedTuple, [ +// [filed_name_1, field_type_1], +// [filed_name_2, field_type_2] +// ] +// ]" +// Example NamedTuple type: +// "__torch__.base_models.sparse_nn.pytorch_preproc_types.PreprocOutputType[ +// NamedTuple, [ 
+// [float_features, Tensor], +// [id_list_features, List[Tensor]], +// [label, Tensor], +// [weight, Tensor], +// ] +// ]" +TypePtr TypeParser::parseNamedTuple(const std::string& qualified_name) { + std::vector field_names; + std::vector field_types; + std::string ns; + expect(","); + expect("["); + while (cur() != "]") { + expect("["); + std::string field_name = next(); + expect(","); + TypePtr field_type = parse(); + field_names.emplace_back(field_name); + field_types.emplace_back(field_type); + expect("]"); + if (cur() == ",") { + next(); + } + } + return TupleType::createNamed(qualified_name, field_names, field_types); +} + +// Custom type will be following structure: +// "qualified_named[ +// custom_type, [ +// [filed_name_1, field_type_1], +// [filed_name_2, field_type_2] +// ] +// ]" +TypePtr TypeParser::parseCustomType() { + c10::string_view token = cur(); + std::string qualified_name = "__torch__."; + qualified_name.reserve(qualified_name.size() + token.size()); + qualified_name.append(token.begin(), token.end()); + next(); + while (cur() == ".") { + qualified_name.append(next()); + qualified_name.append(next()); + } + // After cur() moves to the next token after qualified name, if it's "[", it + // means this custom type follow by it's class definition. Otherwise, it's a + // barebone qualified name and needs to look up str_type_ptr_map_ to find + // the typeptr. + if (cur() == "[") { + next(); + std::string type_name = next(); + // Currently only supports NamedTuple custom type, if more types need to + // be supported, extend them here. 
+ if (type_name == kTypeNamedTuple) { + contained_types_.insert(kTypeNamedTuple); + return parseNamedTuple(qualified_name); + } else { + TORCH_CHECK( + false, "Custom Type ", type_name, " is not supported in the parser."); + } + } else { + auto find_type = str_type_ptr_map_.find(qualified_name); + if (find_type != str_type_ptr_map_.end()) { + return find_type->second; + } else { + // When the type definition can't be found, likely two reasons + // 1. The type list in bytecode.pkl is not in the correct order + // 2. This custom type definition doesn't exist in bytecode.pkl type + // table + TORCH_CHECK( + false, "Can't find definition for the type: ", qualified_name); + } + return nullptr; + } +} + TypePtr TypeParser::parseTorchbindClassType() { - static constexpr std::array expected_atoms = { - ".", "torch", ".", "classes", "."}; + static constexpr std::array expected_atoms = { + "torch", ".", "classes", "."}; for (const auto& atom : expected_atoms) { expect(atom); } - std::string ns = next(); expectChar('.'); std::string classname = next(); @@ -208,9 +338,14 @@ C10_NODISCARD c10::string_view TypeParser::cur() const { return next_token_; } -TORCH_API TypePtr parseType(const std::string& pythonStr) { - TypeParser parser(pythonStr); +TORCH_API at::TypePtr parseType(const std::string& pythonStr) { + at::TypeParser parser(pythonStr); return parser.parse(); } +TORCH_API std::vector parseType( + std::vector& pythonStrs) { + at::TypeParser parser(pythonStrs); + return parser.parseList(); +} } // namespace c10 diff --git a/torch/csrc/jit/mobile/type_parser.h b/torch/csrc/jit/mobile/type_parser.h index 056f396f58da28..a77b78f37707d1 100644 --- a/torch/csrc/jit/mobile/type_parser.h +++ b/torch/csrc/jit/mobile/type_parser.h @@ -4,13 +4,17 @@ namespace c10 { class TypeParser { public: explicit TypeParser(std::string pythonStr); + explicit TypeParser(std::vector& pythonStrs); TypePtr parse(); + std::vector parseList(); static std::unordered_set getNonSimpleType(); static 
std::unordered_set getCustomType(); std::unordered_set getContainedTypes(); private: + TypePtr parseNamedTuple(const std::string& qualified_name); + TypePtr parseCustomType(); TypePtr parseTorchbindClassType(); TypePtr parseNonSimple(const std::string& token); @@ -29,9 +33,15 @@ class TypeParser { size_t start_; c10::string_view next_token_; + // Used for parsing string list + std::vector pythonStrs_; + std::unordered_map str_type_ptr_map_; + // Store all contained types when parsing a string std::unordered_set contained_types_; }; TORCH_API TypePtr parseType(const std::string& pythonStr); + +TORCH_API std::vector parseType(std::vector& pythonStr); } // namespace c10