Skip to content

Commit

Permalink
[BE][flatbuffer] Remove code duplications and refactor (pytorch#79184)
Browse files Browse the repository at this point in the history
Summary:
Remove code duplication in import.cpp / export_modules.cpp so that:
1. Only one copy of the format-switching logic (detect flatbuffer / is_flatbuffer) remains;
2. Detection of whether flatbuffer support is included happens at runtime (so no more macros).

This also reverses the dependency direction: import.cpp -> flatbuffer_loader.cpp becomes flatbuffer_loader.cpp -> import.cpp.

Differential Revision: D36926217

Pull Request resolved: pytorch#79184
Approved by: https://github.com/zhxchen17
  • Loading branch information
qihqi authored and pytorchmergebot committed Jun 20, 2022
1 parent 7de231a commit fed12ff
Show file tree
Hide file tree
Showing 24 changed files with 617 additions and 560 deletions.
3 changes: 2 additions & 1 deletion BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1645,7 +1645,8 @@ cc_library(
],
)) + libtorch_core_sources + libtorch_distributed_sources + torch_cpp_srcs + libtorch_extra_sources + jit_core_sources + lazy_tensor_ts_sources + GENERATED_AUTOGRAD_CPP + [
"torch/csrc/jit/serialization/flatbuffer_serializer.cpp",
"torch/csrc/jit/mobile/flatbuffer_loader.cpp"
"torch/csrc/jit/mobile/flatbuffer_loader.cpp",
"torch/csrc/jit/serialization/flatbuffer_serializer_jit.cpp",
],
copts = TORCH_COPTS,
defines = [
Expand Down
32 changes: 32 additions & 0 deletions caffe2/serialize/in_memory_adapter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#pragma once
#include <cstring>
#include <caffe2/serialize/read_adapter_interface.h>


namespace caffe2 {
namespace serialize {

// Read adapter over a caller-owned, in-memory buffer. It does NOT take
// ownership: the caller must keep `data` alive (and unchanged) for the
// lifetime of this adapter.
class MemoryReadAdapter final : public caffe2::serialize::ReadAdapterInterface {
 public:
  explicit MemoryReadAdapter(const void* data, off_t size)
      : data_(data), size_(size) {}

  // Total number of bytes available for reading.
  size_t size() const override {
    return size_;
  }

  // Copies `n` bytes starting at byte offset `pos` into `buf` and returns the
  // number of bytes copied. `what` is a description used for error reporting
  // by other adapters; it is unused here. The caller is responsible for
  // keeping `pos + n` within size() — no bounds check is performed.
  size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
      const override {
    (void)what;
    // static_cast keeps const-correctness; the previous C-style cast to
    // `int8_t*` silently dropped the const qualifier on data_.
    std::memcpy(buf, static_cast<const char*>(data_) + pos, n);
    return n;
  }

 private:
  const void* data_; // non-owning pointer to the caller's buffer
  off_t size_; // buffer size in bytes
};


} // namespace serialize
} // namespace caffe2
18 changes: 2 additions & 16 deletions test/cpp/jit/test_flatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ TEST(FlatbufferTest, MethodInvocation) { // NOLINT (use =delete in gtest)
}
}

#if defined(ENABLE_FLATBUFFER) && !defined(FB_XPLAT_BUILD)
#if !defined(FB_XPLAT_BUILD)
TEST(FlatbufferTest, FlatbufferBackPortTest) {
Module m("m");
m.define(R"(
Expand All @@ -188,7 +188,7 @@ TEST(FlatbufferTest, FlatbufferBackPortTest) {
bool backPortSuccess = _backport_for_mobile(ss, oss, 5);
ASSERT_TRUE(backPortSuccess);
}
#endif // defined(ENABLE_FLATBUFFER) && !defined(FB_XPLAT_BUILD)
#endif // !defined(FB_XPLAT_BUILD)

TEST(FlatbufferTest, ExtraFiles) {
const auto script = R"JIT(
Expand All @@ -207,7 +207,6 @@ TEST(FlatbufferTest, ExtraFiles) {
extra_files["mobile_info.json"] = "{\"key\": 23}";

std::unordered_map<std::string, std::string> loaded_extra_files;
#if defined ENABLE_FLATBUFFER
std::stringstream ss;
module->_save_for_mobile(ss, extra_files, true, /*use_flatbuffer=*/true);

Expand All @@ -219,17 +218,6 @@ TEST(FlatbufferTest, ExtraFiles) {

// load it twice using the same stream
auto mobile_module2 = _load_for_mobile(ss, c10::nullopt, loaded_extra_files);
#else
CompilationOptions options;
mobile::Module bc = jitModuleToMobile(*module, options);
auto buff = save_mobile_module_to_bytes(bc, extra_files);

loaded_extra_files["metadata.json"] = "";
auto* flatbuffer_module =
mobile::serialization::GetMutableModule(buff.data());

parseExtraFiles(flatbuffer_module, loaded_extra_files);
#endif

ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");
Expand Down Expand Up @@ -1283,7 +1271,6 @@ Module jitModuleFromBuffer(void* data) {
mobilem._ivalue(), files, constants, 8);
}

#if defined(ENABLE_FLATBUFFER)
TEST(TestSourceFlatbuffer, UpsampleNearest2d) {
Module m("m");
m.define(R"(
Expand Down Expand Up @@ -1375,7 +1362,6 @@ TEST(TestSourceFlatbuffer,
AT_ASSERT(resd == refd);
}
}
#endif

#if !defined FB_XPLAT_BUILD
// The following test run in fbcode only
Expand Down
8 changes: 2 additions & 6 deletions test/cpp/jit/test_lite_interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ void backportAllVersionCheck(
std::vector<IValue>& expect_result_list,
const uint64_t expect_from_version) {
auto from_version = _get_model_bytecode_version(test_model_file_stream);
AT_ASSERT(from_version == expect_from_version);
EXPECT_EQ(from_version, expect_from_version);
AT_ASSERT(from_version > 0);

// Backport script_module_v5.ptl to an older version
Expand Down Expand Up @@ -717,15 +717,11 @@ TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) {
torch::jit::Module module_freeze = freeze(module);

std::stringstream input_model_stream;
#if defined(ENABLE_FLATBUFFER)
module_freeze._save_for_mobile(
input_model_stream,
/*extra_files=*/{},
/*save_mobile_debug_info=*/false,
/*use_flatbuffer=*/true);
#else
module_freeze._save_for_mobile(input_model_stream);
#endif
std::vector<IValue> input_data =
std::vector<IValue>({torch::ones({1, 1, 28, 28})});
std::vector<IValue> expect_result_list;
Expand All @@ -748,7 +744,7 @@ TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) {
input_model_stream,
input_data,
expect_result_list,
caffe2::serialize::kProducedBytecodeVersion);
9); // flatbuffer starts at 9
}
#endif // !defined(FB_XPLAT_BUILD)

Expand Down
56 changes: 3 additions & 53 deletions test/cpp/jit/test_lite_trainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <torch/csrc/jit/mobile/train/optim/sgd.h>
#include <torch/csrc/jit/mobile/train/random.h>
#include <torch/csrc/jit/mobile/train/sequential.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer_jit.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/data/dataloader.h>
#include <torch/torch.h>
Expand Down Expand Up @@ -172,9 +173,9 @@ TEST(MobileTest, SaveParametersDefaultsToZip) {
EXPECT_EQ(ss_data.str()[3], '\x04');
}

#if defined(ENABLE_FLATBUFFER)
TEST(MobileTest, SaveParametersCanUseFlatbuffer) {
// Save some empty parameters using flatbuffer.
register_flatbuffer_all();
std::map<std::string, at::Tensor> empty_parameters;
std::stringstream ss_data;
_save_parameters(empty_parameters, ss_data, /*use_flatbuffer=*/true);
Expand All @@ -188,34 +189,10 @@ TEST(MobileTest, SaveParametersCanUseFlatbuffer) {
EXPECT_EQ(ss_data.str()[6], 'M');
EXPECT_EQ(ss_data.str()[7], 'F');
}
#else // !defined(ENABLE_FLATBUFFER)
TEST(MobileTest, SaveParametersThrowsWithoutFlatbufferSupport) {
  // An empty parameter map is enough to exercise the save path.
  std::map<std::string, at::Tensor> params;
  std::stringstream out;

  // Saving with flatbuffers must fail when support isn't compiled in, and the
  // error must explicitly mention the missing flatbuffer support — not just
  // throw any exception.
  static const std::string kExpectedSubstring =
      "build hasn't enabled flatbuffer";
  try {
    _save_parameters(params, out, /*use_flatbuffer=*/true);
    FAIL() << "_save_parameters should have thrown";
  } catch (const ::c10::Error& e) {
    const std::string message{e.msg()};
    EXPECT_TRUE(message.find(kExpectedSubstring) != std::string::npos)
        << "Exception message does not contain expected substring \""
        << kExpectedSubstring << "\": actual message \"" << e.msg() << "\"";
  } catch (...) {
    FAIL() << "Unexpected exception type";
  }
}
#endif // !defined(ENABLE_FLATBUFFER)

#if defined(ENABLE_FLATBUFFER)
TEST(MobileTest, SaveLoadParametersUsingFlatbuffers) {
// Create some simple parameters to save.
register_flatbuffer_all();
std::map<std::string, at::Tensor> input_params;
input_params["four_by_ones"] = 4 * torch::ones({});
input_params["three_by_ones"] = 3 * torch::ones({});
Expand Down Expand Up @@ -244,33 +221,6 @@ TEST(MobileTest, SaveLoadParametersUsingFlatbuffers) {
output_params["three_by_ones"].item<int>(), three_by_ones.item<int>());
}
}
#else // !defined(ENABLE_FLATBUFFER)
TEST(MobileTest, LoadParametersFailsWithoutFlatbufferSupport) {
  // Construct bytes that carry the flatbuffer magic at the expected offset so
  // the loader identifies them as flatbuffer data.
  std::stringstream data;
  data << "abcd"
       << "PTMF" // Flatbuffer magic
       << "ijkl";

  // Loading must fail with the specific error that flatbuffer support is
  // missing. We check the message rather than the mere fact of a throw:
  // malformed flatbuffer data could raise other errors, but nothing should
  // ever get as far as parsing it in this build.
  static const std::string kExpectedSubstring =
      "build hasn't enabled flatbuffer";
  try {
    _load_parameters(data);
    FAIL() << "_load_parameters should have thrown";
  } catch (const ::c10::Error& e) {
    const std::string message{e.msg()};
    EXPECT_TRUE(message.find(kExpectedSubstring) != std::string::npos)
        << "Exception message does not contain expected substring \""
        << kExpectedSubstring << "\": actual message \"" << e.msg() << "\"";
  } catch (...) {
    FAIL() << "Unexpected exception type";
  }
}
#endif // !defined(ENABLE_FLATBUFFER)

TEST(MobileTest, LoadParametersUnexpectedFormatShouldThrow) {
// Manually create some data that doesn't look like a ZIP or Flatbuffer file.
Expand Down
2 changes: 1 addition & 1 deletion tools/target_definitions.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def add_torch_libs():
] if enable_flatbuffer else []),
link_whole = True,
include_directories = include_directories,
propagated_pp_flags = propagated_pp_flags_cpu + (["-DENABLE_FLATBUFFER"] if enable_flatbuffer else []),
propagated_pp_flags = propagated_pp_flags_cpu,
exported_deps = (
[
":ATen-cpu",
Expand Down
87 changes: 42 additions & 45 deletions torch/csrc/jit/mobile/compatibility/model_compatibility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@
#include <torch/csrc/jit/api/compilation_unit.h> // removed after using simple type_resolver/obj_loader
#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
#include <torch/csrc/jit/mobile/file_format.h>
#if defined(ENABLE_FLATBUFFER)
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#endif
#include <torch/csrc/jit/mobile/import.h> // removed after using simple type_resolver/obj_loader
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_read.h>

#include <caffe2/serialize/in_memory_adapter.h>
#include <sstream>
#include <string>
#include <unordered_set>
Expand Down Expand Up @@ -71,59 +69,33 @@ std::vector<IValue> get_bytecode_ivalues(PyTorchStreamReader& reader) {
// Forward declare
uint64_t _get_model_bytecode_version(
const std::vector<IValue>& bytecode_ivalues);
static uint64_t _get_model_bytecode_version_from_bytes(char* data, size_t size);

uint64_t _get_model_bytecode_version(std::istream& in) {
auto orig_pos = in.tellg();
in.seekg(0, in.beg);
auto format = getFileFormat(in);
switch (format) {
case FileFormat::FlatbufferFileFormat: {
#if !defined(ENABLE_FLATBUFFER)
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
#else
return get_bytecode_version(in);
#endif
}
case FileFormat::ZipFileFormat: {
std::unique_ptr<IStreamAdapter> rai =
std::make_unique<IStreamAdapter>(&in);
auto version = _get_model_bytecode_version(std::move(rai));
in.seekg(orig_pos, in.beg);
return version;
}

default:
TORCH_CHECK(false, "Unrecognized data format");
}
std::shared_ptr<char> data;
size_t size = 0;
std::tie(data, size) = get_stream_content(in);
in.seekg(orig_pos, in.beg);
return _get_model_bytecode_version_from_bytes(data.get(), size);
}

uint64_t _get_model_bytecode_version(const std::string& filename) {
auto format = getFileFormat(filename);
switch (format) {
case FileFormat::FlatbufferFileFormat: {
#if !defined(ENABLE_FLATBUFFER)
TORCH_CHECK(
false,
"Flatbuffer input file but the build hasn't enabled flatbuffer");
#else
return get_bytecode_version(filename);
#endif
}
case FileFormat::ZipFileFormat: {
std::unique_ptr<FileAdapter> rai =
std::make_unique<FileAdapter>(filename);
return _get_model_bytecode_version(std::move(rai));
}

default:
TORCH_CHECK(false, "Unrecognized data format");
}
std::ifstream ifile(filename);
return _get_model_bytecode_version(ifile);
}

uint64_t _get_model_bytecode_version(
    std::shared_ptr<ReadAdapterInterface> rai) {
  // Pull the adapter's entire content into memory, then dispatch on the
  // detected file format via the byte-level overload.
  std::shared_ptr<char> content;
  size_t nbytes = 0;
  std::tie(content, nbytes) = get_rai_content(rai.get());
  return _get_model_bytecode_version_from_bytes(content.get(), nbytes);
}

uint64_t _get_model_bytecode_version_zip(
std::shared_ptr<ReadAdapterInterface> rai) {
if (!check_zip_file(rai)) {
TORCH_CHECK(
false,
Expand All @@ -134,6 +106,31 @@ uint64_t _get_model_bytecode_version(
return _get_model_bytecode_version(bytecode_values);
}

// Determines the bytecode version of a serialized module held entirely in
// memory. Supports both the flatbuffer and the zip container formats.
uint64_t _get_model_bytecode_version_from_bytes(char* data, size_t size) {
  // Need at least a full header before we can sniff the format.
  TORCH_CHECK(size >= kFileFormatHeaderSize, "Unrecognized data format");

  const auto format = getFileFormat(data);
  if (format == FileFormat::FlatbufferFileFormat) {
    // The flatbuffer hook is registered at runtime; a null pointer means the
    // flatbuffer loader was not linked into this build.
    TORCH_CHECK(
        get_flatbuffer_bytecode_version != nullptr,
        "Flatbuffer input file but the build hasn't enabled flatbuffer");
    return get_flatbuffer_bytecode_version(data);
  }
  if (format == FileFormat::ZipFileFormat) {
    // Wrap the raw bytes in a read adapter and reuse the zip code path.
    auto rai =
        std::make_unique<caffe2::serialize::MemoryReadAdapter>(data, size);
    return _get_model_bytecode_version_zip(std::move(rai));
  }
  TORCH_CHECK(false, "Unrecognized data format");
}

uint64_t _get_model_bytecode_version(
const std::vector<IValue>& bytecode_ivalues) {
if (!bytecode_ivalues.empty() && bytecode_ivalues[0].isInt()) {
Expand Down
Loading

0 comments on commit fed12ff

Please sign in to comment.