Make debug_pkl smaller by only emitting unique traces. (pytorch#73368)
Summary:
Pull Request resolved: pytorch#73368

The debug_pkl file inside PyTorch's .pt file consists of a list of SourceRanges. Each SourceRange points to a Source, which holds a stack trace, a filename, and start/end offsets; these are emitted into the debug_pkl file as strings.
Since many SourceRanges share the same Source, the trace strings can be deduplicated.
The newer format stores the set of unique traces in a tuple, and each SourceRange records the offset of its trace within that tuple (i.e. manually applied dictionary compression).
This shrinks the file, but on loading, if we copied each trace back into a Source as a string, runtime memory would still blow up.
To mitigate this, we use SourceView directly instead of Source: it takes a reference to the string owned by the Deserializer and exposes it as a string_view. This is safe because the Deserializer is held by the Unpickler via shared_ptr, and the Unpickler is in turn held via shared_ptr by another Source object, which stays alive for the duration of model construction.
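
As a rough sketch of the dictionary-compression idea (illustrative only; the record layout and helper names below are hypothetical, not the actual serializer format):

```python
# Deduplicate trace strings into a tuple of unique traces; each debug record
# then stores an offset into that tuple instead of the full trace text.
# Hypothetical layout, for illustration of the format change only.

def compress_debug_records(records):
    """records: iterable of (trace_text, start, end) triples."""
    unique_traces = []
    offset_of = {}
    compressed = []
    for trace, start, end in records:
        if trace not in offset_of:              # first occurrence of this trace
            offset_of[trace] = len(unique_traces)
            unique_traces.append(trace)
        compressed.append((offset_of[trace], start, end))
    return tuple(unique_traces), compressed


def resolve(trace_table, record):
    offset, start, end = record
    return trace_table[offset], start, end      # shared lookup, no per-record copy
```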

Test Plan:
Unit tests.

Took the original file (312271638_930.predictor.disagg.local), loaded it with `torch.jit.load`, and saved it again with `torch.jit.save`. Unzipped both archives and compared their contents:
```
[[email protected] ~]$ du archive -h
4.0K    archive/xl_model_weights
3.7M    archive/extra
8.0K    archive/code/__torch__/caffe2/torch/fb/model_transform/splitting
8.0K    archive/code/__torch__/caffe2/torch/fb/model_transform
8.0K    archive/code/__torch__/caffe2/torch/fb
8.0K    archive/code/__torch__/caffe2/torch
8.0K    archive/code/__torch__/caffe2
20M     archive/code/__torch__/torch/fx/graph_module
20M     archive/code/__torch__/torch/fx
8.0K    archive/code/__torch__/torch/classes
20M     archive/code/__torch__/torch
20M     archive/code/__torch__
20M     archive/code
2.7M    archive/constants
35M     archive
[[email protected] ~]$ du resaved -h
4.0K    resaved/extra
8.0K    resaved/code/__torch__/caffe2/torch/fb/model_transform/splitting
8.0K    resaved/code/__torch__/caffe2/torch/fb/model_transform
8.0K    resaved/code/__torch__/caffe2/torch/fb
8.0K    resaved/code/__torch__/caffe2/torch
8.0K    resaved/code/__torch__/caffe2
1.3M    resaved/code/__torch__/torch/fx/graph_module
1.3M    resaved/code/__torch__/torch/fx
8.0K    resaved/code/__torch__/torch/classes
1.4M    resaved/code/__torch__/torch
1.4M    resaved/code/__torch__
1.4M    resaved/code
2.7M    resaved/constants
13M     resaved
[[email protected] ~]$
```
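
For reference, the load-and-resave roundtrip above boils down to a few lines (a sketch; the paths here are placeholders, not the original predictor file):

```python
import torch

m = torch.jit.load("original.pt")    # archive written with the old debug_pkl format
torch.jit.save(m, "resaved.pt")      # rewritten with the deduplicated trace table
# A .pt file is a zip archive, so unzip both and compare directory sizes with `du -h`.
```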

Reviewed By: gmagogsfm

Differential Revision: D34455360

fbshipit-source-id: 8cc716f9bba7183746b1b4ecc33a2de34ac503b9
(cherry picked from commit f1a0473)
qihqi authored and pytorchmergebot committed Mar 2, 2022
1 parent 07df887 commit 61d6c43
Showing 25 changed files with 690 additions and 204 deletions.
32 changes: 32 additions & 0 deletions test/cpp/jit/source_range_test.cpp
@@ -0,0 +1,32 @@
#include <gtest/gtest.h>
#include <torch/csrc/jit/frontend/source_range.h>

using namespace ::testing;
using namespace ::torch::jit;

TEST(SourceRangeTest, test_find) {
  std::vector<std::shared_ptr<std::string>> strings;
  strings.push_back(std::make_shared<std::string>("hello world"));
  strings.push_back(std::make_shared<std::string>("nihaoma"));

  std::vector<c10::string_view> pieces{*strings[0], *strings[1]};

  StringCordView view(pieces, strings);

  auto x = view.find("rldni", 0);
  EXPECT_EQ(x, 8);
}

TEST(SourceRangeTest, test_substr) {
  std::vector<std::shared_ptr<std::string>> strings;
  strings.push_back(std::make_shared<std::string>("hello world"));
  strings.push_back(std::make_shared<std::string>("nihaoma"));

  std::vector<c10::string_view> pieces{*strings[0], *strings[1]};

  StringCordView view(pieces, strings);

  auto x = view.substr(4, 10).str();
  EXPECT_EQ(x, view.str().substr(4, 10));
  EXPECT_EQ(view.substr(0, view.size()).str(), view.str());
}
82 changes: 82 additions & 0 deletions test/cpp/jit/test_backend.cpp
@@ -143,6 +143,38 @@ TEST(BackendTest, TestCompiler) {
AT_ASSERT(mres.toTensor().equal(ref.toTensor()));
}

TEST(BackendTest, TestCompilerWithStringTable) {
setShouldUseFormatWithStringTable(true);
Module m("m");
m.define(R"(
def forward(self, x, h):
return x + h
)");

std::vector<IValue> inputs;
inputs.emplace_back(2.0 * torch::ones({}));
inputs.emplace_back(1.0 * torch::ones({}));
auto ref = m.forward(inputs);

c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
fake_dict.insert("", "");
compile_spec.insert("forward", fake_dict);
auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
// lowered module
auto lm = torch::jit::detail::codegen_backend_module(
"backend_with_compiler_demo", m, compile_spec, any_dict_ty);
auto res = lm.forward(inputs);
AT_ASSERT(res.toTensor().equal(ref.toTensor()));

std::stringstream ss;
lm._save_for_mobile(ss);
auto mlm = _load_for_mobile(ss);
auto mres = mlm.forward(inputs);
setShouldUseFormatWithStringTable(false);
AT_ASSERT(mres.toTensor().equal(ref.toTensor()));
}

TEST(BackendTest, TestComposite) {
c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
@@ -383,6 +415,56 @@ Traceback of TorchScript (most recent call last):
ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
}

TEST(BackendTestDebugInfo, TestCompilerWithStringTable) {
setShouldUseFormatWithStringTable(true);
Module m("m");
m.define(R"(
def forward(self, x, h):
return x + h
)");

std::vector<IValue> inputs;
inputs.emplace_back(torch::rand({2, 4}));
inputs.emplace_back(torch::rand({13, 9}));

c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
fake_dict.insert("", "");
compile_spec.insert("forward", fake_dict);
auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
// lowered module
auto lm = torch::jit::detail::codegen_backend_module(
"backend_with_compiler_demo", m, compile_spec, any_dict_ty);

std::stringstream ss;
lm._save_for_mobile(ss, ExtraFilesMap(), true);
auto mlm = _load_for_mobile(ss);
std::string error_pattern = R"(
Module hierarchy:top(m)::<unknown>.__loweredModule__(m)::forward.aten::add
Traceback of TorchScript (most recent call last):
File "<string>", line 3, in <unknown>
def forward(self, x: Tensor, h: Tensor):
return self.__loweredModule__.forward(x, h)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
File "<string>", line 5, in forward
typed_inputs: List[Any] = [x, h, ]
if self.__backend.is_available() :
_0, = self.__backend.execute(self.__handles["forward"], typed_inputs)
~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
assert isinstance(_0, Tensor)
return _0
File "<string>", line 3, in <unknown>
def forward(self, x, h):
return x + h
~~~~~ <--- HERE
)";
setShouldUseFormatWithStringTable(false);
ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
}

TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithModuleHierarchy) {
Module a("A");
a.define(R"(
22 changes: 12 additions & 10 deletions test/cpp/jit/test_utils.h
@@ -38,16 +38,18 @@ static inline void trim(std::string& s) {
}
} // namespace

#define ASSERT_THROWS_WITH_MESSAGE(statement, substring) \
try { \
(void)statement; \
FAIL(); \
} catch (const std::exception& e) { \
std::string substring_s(substring); \
trim(substring_s); \
auto exception_string = std::string(e.what()); \
trim(exception_string); \
ASSERT_NE(exception_string.find(substring_s), std::string::npos); \
#define ASSERT_THROWS_WITH_MESSAGE(statement, substring) \
try { \
(void)statement; \
FAIL(); \
} catch (const std::exception& e) { \
std::string substring_s(substring); \
trim(substring_s); \
auto exception_string = std::string(e.what()); \
trim(exception_string); \
ASSERT_NE(exception_string.find(substring_s), std::string::npos) \
<< " Error was: \n" \
<< exception_string; \
}

namespace torch {
10 changes: 8 additions & 2 deletions torch/csrc/jit/frontend/function_schema_parser.cpp
@@ -2,6 +2,7 @@

#include <ATen/core/Reduction.h>
#include <ATen/core/type_factory.h>
#include <c10/util/Optional.h>
#include <c10/util/string_utils.h>
#include <torch/csrc/jit/frontend/lexer.h>
#include <torch/csrc/jit/frontend/parse_string_literal.h>
@@ -27,8 +28,13 @@ namespace jit {

namespace {
struct SchemaParser {
SchemaParser(const std::string& str)
: L(std::make_shared<SourceView>(c10::string_view(str))),
explicit SchemaParser(const std::string& str)
: L(std::make_shared<Source>(
c10::string_view(str),
c10::nullopt,
0,
nullptr,
Source::DONT_COPY)),
type_parser(L, /*parse_complete_tensor_types*/ false) {}

either<OperatorName, FunctionSchema> parseDeclaration() {
26 changes: 10 additions & 16 deletions torch/csrc/jit/frontend/lexer.h
@@ -190,7 +190,7 @@ struct TORCH_API SharedParserData {
// find the longest match of str.substring(pos) against a token, return true
// if successful filling in kind, start,and len
bool match(
c10::string_view str,
StringCordView str,
size_t pos,
bool continuation, // are we inside a scope where newlines don't count
// (e.g. inside parens)
@@ -241,12 +241,12 @@
// invariant: the next token is not whitespace or newline
*start = pos;
// check for a valid number
if (isNumber(str, pos, len)) {
if (isNumber(str.piece(0), pos, len)) {
*kind = TK_NUMBER;
return true;
}
// check for string
if (isString(str, pos, len)) {
if (isString(str.piece(0), pos, len)) {
*kind = TK_STRINGLITERAL;
return true;
}
@@ -369,7 +369,7 @@ struct TORCH_API SharedParserData {
return isspace(n) && n != '\n';
}
// Make an exception ignoring comments for type annotation comments
bool isTypeComment(c10::string_view str, size_t pos) {
bool isTypeComment(StringCordView str, size_t pos) {
const std::string type_string = "# type:";
if (str.size() < pos + type_string.length()) {
return false;
Expand All @@ -388,15 +388,15 @@ struct Token {
SourceRange range;
Token(int kind, SourceRange range) : kind(kind), range(std::move(range)) {}
std::string text() {
return range.text();
return range.text().str();
}
std::string kindString() const {
return kindToString(kind);
}
};

struct Lexer {
explicit Lexer(std::shared_ptr<SourceView> source)
explicit Lexer(std::shared_ptr<Source> source)
: source(std::move(source)),
pos(0),
nesting(0),
@@ -519,25 +519,19 @@ struct Lexer {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
size_t length;
AT_ASSERT(source);
auto src = source->text_str();
if (!shared.match(
source->text(),
pos,
nesting > 0,
whitespace_token,
&kind,
&start,
&length)) {
src, pos, nesting > 0, whitespace_token, &kind, &start, &length)) {
expected(
"a valid token",
Token(
(source->text())[start], SourceRange(source, start, start + 1)));
Token(source->char_at(start), SourceRange(source, start, start + 1)));
}
auto t = Token(kind, SourceRange(source, start, start + length));
pos = start + length;
return t;
}

std::shared_ptr<SourceView> source;
std::shared_ptr<Source> source;
size_t pos;
size_t nesting; // depth of ( [ { nesting...
std::vector<int> indent_stack; // stack of indentation level of blocks
4 changes: 2 additions & 2 deletions torch/csrc/jit/frontend/parser.cpp
@@ -46,7 +46,7 @@ Decl mergeTypesFromTypeComment(
}

struct ParserImpl {
explicit ParserImpl(const std::shared_ptr<SourceView>& source)
explicit ParserImpl(const std::shared_ptr<Source>& source)
: L(source), shared(sharedParserData()) {}

Ident parseIdent() {
@@ -801,7 +801,7 @@ struct ParserImpl {
SharedParserData& shared;
};

Parser::Parser(const std::shared_ptr<SourceView>& src)
Parser::Parser(const std::shared_ptr<Source>& src)
: pImpl(new ParserImpl(src)) {}

Parser::~Parser() = default;
2 changes: 1 addition & 1 deletion torch/csrc/jit/frontend/parser.h
@@ -17,7 +17,7 @@ TORCH_API Decl mergeTypesFromTypeComment(
bool is_method);

struct TORCH_API Parser {
explicit Parser(const std::shared_ptr<SourceView>& src);
explicit Parser(const std::shared_ptr<Source>& src);
TreeRef parseFunction(bool is_method);
TreeRef parseClass();
Decl parseTypeComment();
2 changes: 1 addition & 1 deletion torch/csrc/jit/frontend/script_type_parser.cpp
@@ -227,7 +227,7 @@ TypePtr ScriptTypeParser::parseTypeFromExpr(const Expr& expr) const {
// expression and base type names.
if (resolver_) {
if (auto typePtr =
resolver_->resolveType(expr.range().text(), expr.range())) {
resolver_->resolveType(expr.range().text().str(), expr.range())) {
return typePtr;
}
}