Skip to content

Commit

Permalink
Fixing few bugs in torch flatbuffer (pytorch#72349)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#72349

1. Interface call'd methods need to be registered to class. Previously all interface calls are inlined so  there was no such problem.
2. parseDoubleList and parseBoolList got reversed when refactoring.

Test Plan:
1. Get ASR's test model at
```
mkdir ~/asr1 && cd ~/asr1
fbpkg fetch speech.tuna.milan.ondevice.en_us
```
2. Convert model:
```
cd ~/fbsource
buck run //xplat/caffe2/fb/lite_predictor:convert_model -- --model=$HOME/asr1/pytorchmodel.pt --output_name=$HOME/asr1/pytorchmodel.ff
```
3. Ran lite_predictor_flatbuffer
```
 buck run //xplat/caffe2/fb/lite_predictor:lite_predictor_flatbuffer -- --model=$HOME/asr1/pytorchmodel.ff --method_to_call=encode_src --method_to_generate_input=get_all_bundled_inputs_for_encode_src

```

See perf metric generated (means loading and inference succeeded).

Reviewed By: gmagogsfm, zhxchen17

Differential Revision: D33959746

fbshipit-source-id: 24671e1189438119f477032eb6c29bd7736e74ca
(cherry picked from commit 5e18809)
  • Loading branch information
qihqi authored and pytorchmergebot committed Feb 5, 2022
1 parent f2f40ce commit 57f039b
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
22 changes: 22 additions & 0 deletions test/cpp/jit/test_flatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,28 @@ TEST(FlatbufferTest, OperatorSize1) {
func2.get_code().operators_.size());
}

TEST(FlatbufferTest, BoolAndDoubleList) {
Module m("m");
c10::List<bool> boollist;
boollist.push_back(false);
IValue boollist_ival = boollist;
IValue doublelist = std::vector<double>{2.0};
m.register_attribute("bool_list", boollist_ival.type(), boollist_ival);
m.register_attribute("double_list", doublelist.type(), doublelist);

CompilationOptions options;
mobile::Module bc = jitModuleToMobile(m, options);
auto buff = save_mobile_module_to_bytes(bc);
mobile::Module bc2 = parse_mobile_module(buff.data(), buff.size());

// if the variables read are wrong type the conversion will raise exception
auto boolval = bc2.attr("bool_list", {}).toBoolList().get(0);
auto doubleval = bc2.attr("double_list", {}).toDoubleList().get(0);

ASSERT_EQ(boolval, false);
ASSERT_EQ(doubleval, 2.0);
}

TEST(FlatbufferTest, OperatorTest2) { // NOLINT (use =delete in gtest)
const std::vector<std::string> test_programs{
// test invoking a method with default parameter
Expand Down
14 changes: 11 additions & 3 deletions torch/csrc/jit/mobile/flatbuffer_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,16 @@ mobile::Module FlatbufferLoader::parseModule(
all_ivalues_[i] = parseIValue(ival);
}
}

IValue& module_ivalue = getIValue(module->state_obj());

// register functions
for (const auto& f : all_functions_) {
uint32_t class_index =
ivalues->Get(f.first)->val_as_Function()->class_type();
ClassTypePtr class_type = all_types_[class_index];
class_type->addMethod(f.second);
}

return mobile::Module(module_ivalue.toObject(), mcu_);
}

Expand Down Expand Up @@ -368,14 +376,14 @@ IValue parseIntList(
return parseListNative<int64_t>(list);
}

IValue parseBoolList(
IValue parseDoubleList(
FlatbufferLoader&,
const mobile::serialization::IValue& ivalue) {
const auto& list = ivalue.val_as_DoubleList();
return parseListNative<double>(list);
}

IValue parseDoubleList(
IValue parseBoolList(
FlatbufferLoader&,
const mobile::serialization::IValue& ivalue) {
const auto& list = ivalue.val_as_BoolList();
Expand Down
8 changes: 4 additions & 4 deletions torch/csrc/jit/serialization/flatbuffer_serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ flatbuffers::Offset<mobile::serialization::Function> FlatbufferSerializer::
};

flatbuffers::Offset<mobile::serialization::Schema> schema_offset = 0;
uint32_t class_index = 0;
if (func.hasSchema()) {
const auto& schema = func.getSchema();
TORCH_CHECK(
Expand All @@ -249,14 +250,13 @@ flatbuffers::Offset<mobile::serialization::Function> FlatbufferSerializer::
"A variable number of return values is not supported in mobile modules.");
schema_offset =
CreateFBSchema(fbb, schema.arguments(), schema.returns(), type_printer);
auto classtype = schema.arguments()[0].type()->cast<ClassType>();
class_index = storeClassTypeAndGetIndex(fbb, classtype);
}

auto debug_info_offset =
CreateDebugInfo(fbb, fbb.CreateVector(code.debug_handles_));

// auto classtype = schema.arguments()[0].type()->cast<ClassType>();
// uint32_t class_type = storeClassTypeAndGetIndex(fbb, classtype);

auto function_offset = CreateFunctionDirect(
fbb,
qn.c_str(),
Expand All @@ -267,7 +267,7 @@ flatbuffers::Offset<mobile::serialization::Function> FlatbufferSerializer::
register_size,
schema_offset,
debug_info_offset,
0);
class_index);
return function_offset;
}

Expand Down

0 comments on commit 57f039b

Please sign in to comment.