Commit

Attribute serialization improvements (pytorch#18188)
Summary:
* adds attributes to `ScriptModule.__getattr__` so they can be accessed in Python after re-importing
* full support for all possible values of an `int64_t`
    * this necessitated several more `pushWhatever` functions, so a templated version was re-introduced to cut down on duplicate code
* tests to validate that reference / value sharing works
* adds `torch.jit.Unpickler`, which people can use to de-serialize the pickle files into Python and as a quick reference for how to do this without PyTorch (a usage sketch follows this list)
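
For context, here is a minimal sketch (not part of this commit) of how the `attributes.pkl` entry written by `ScriptModule.save()` can be inspected with only the standard library, mirroring `test_serialization_sharing` below. The module `M`, the filename `m.pt`, and the `<archive_name>/attributes.pkl` layout are taken from those tests rather than from a documented API; `torch.jit.Unpickler` itself is not shown because its exact interface does not appear in this excerpt.

```python
# Illustrative sketch only: inspect the attribute pickle that
# ScriptModule.save() writes, using just the standard library.
import os
import pickletools
import zipfile
from typing import List

import torch


class M(torch.jit.ScriptModule):
    def __init__(self):
        super(M, self).__init__()
        # Attributes are serialized into attributes.pkl and, with this change,
        # are reachable via __getattr__ after the module is re-imported.
        self.words = torch.jit.Attribute(["hello", "world"], List[str])
        self.big = torch.jit.Attribute(2**63 - 1, int)

    @torch.jit.script_method
    def forward(self):
        # type: () -> int
        return self.big


m = M()
m.save("m.pt")  # hypothetical filename

# The saved file is a zip archive; attributes live in <archive_name>/attributes.pkl,
# as in the tests below.
archive_name = os.path.basename(os.path.normpath("m.pt"))
with zipfile.ZipFile("m.pt", "r") as archive:
    pickled_data = archive.read(os.path.join(archive_name, "attributes.pkl"))

# Disassemble the pickle stream to see how values are encoded
# (e.g. BINGET opcodes mark shared / memoized objects).
pickletools.dis(pickled_data)
```

Relying on `pickletools.dis` keeps the inspection dependency-free, which is the same approach the new sharing test takes when it checks for `BINGET` opcodes.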
Pull Request resolved: pytorch#18188

Differential Revision: D14527490

Pulled By: driazati

fbshipit-source-id: efd15579cc04aa2e28c4b2c9490d82d849dee559
David Riazati authored and facebook-github-bot committed Mar 30, 2019
1 parent e13101e commit 24db166
Showing 9 changed files with 194 additions and 75 deletions.
70 changes: 67 additions & 3 deletions test/test_jit.py
@@ -16,7 +16,7 @@
from torch.autograd.function import traceable
from torch.testing import assert_allclose
from torch.onnx import OperatorExportTypes
-from torch._six import inf, PY2, builtins
+from torch._six import inf, PY2, builtins, StringIO
from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
freeze_rng_state, set_rng_seed, slowTest
@@ -37,7 +37,10 @@
import math
import types
import pickle
import pickletools
import copy
import zipfile


from common_methods_invocations import method_tests as autograd_method_tests
from common_methods_invocations import create_input, unpack_variables, \
@@ -10488,8 +10491,6 @@ def forward(self):

@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
def test_attribute_unpickling(self):
-import zipfile
-
class M(torch.jit.ScriptModule):
def __init__(self):
super(M, self).__init__()
@@ -10557,6 +10558,69 @@ def forward(self):
imported_m = self.getExportImportCopy(m)
self.assertEqual(m(), imported_m())

def test_serialization_big_ints(self):
class M(torch.jit.ScriptModule):
def __init__(self):
super(M, self).__init__()
self.int32_max = torch.jit.Attribute(2**31 - 1, int)
self.int32_min = torch.jit.Attribute(-2**31, int)
self.uint32_max = torch.jit.Attribute(2**32, int)

self.int64_max = torch.jit.Attribute(2**63 - 1, int)
self.int64_min = torch.jit.Attribute(-2**63, int)

self.tensor = torch.nn.Parameter(torch.ones(2, 2))

@torch.jit.script_method
def forward(self, x):
# type: (int) -> (int)
return x + (self.int32_max + self.int32_min) + (self.int64_max + self.int64_min)

m = M()
imported = self.getExportImportCopy(m)
self.assertEqual(m(10), imported(10))

self.assertEqual(m.int32_max, imported.int32_max)
self.assertEqual(m.int32_min, imported.int32_min)
self.assertEqual(m.uint32_max, imported.uint32_max)
self.assertEqual(m.int64_max, imported.int64_max)
self.assertEqual(m.int64_min, imported.int64_min)

def test_serialization_sharing(self):
class M(torch.jit.ScriptModule):
def __init__(self):
super(M, self).__init__()
self.list = torch.jit.Attribute([], List[str])

@torch.jit.script_method
def forward(self, key):
# type: (str) -> List[str]
self.list.append(key)
self.list.append(key)
self.list.append(key)
return self.list

# the text of the string should only appear once in the pickling
m = M()
s1 = "a long string"
s2 = "a different, even longer string"
self.assertEqual(m(s1), [s1] * 3)
self.assertEqual(m(s2), [s1] * 3 + [s2] * 3)
with TemporaryFileName() as fname:
m.save(fname)
archive_name = os.path.basename(os.path.normpath(fname))
archive = zipfile.ZipFile(fname, 'r')
pickled_data = archive.read(os.path.join(archive_name, 'attributes.pkl'))

out = StringIO()
pickletools.dis(pickled_data, out=out)
disassembled = out.getvalue()

FileCheck().check_count(s1, 1, exactly=True) \
.check_count("BINGET", 2, exactly=True) \
.check_count(s2, 1, exactly=True) \
.check_count("BINGET", 2, exactly=True).run(out.getvalue())

def test_optional_tuple(self):
def fn(x=None):
# type: (Optional[Tuple[int, int]]) -> Tuple[int, int]
7 changes: 7 additions & 0 deletions torch/_six.py
@@ -137,6 +137,13 @@ def get_function_from_type(cls, name):
elif PY3:
import builtins

if PY2:
import StringIO
StringIO = StringIO.StringIO
elif PY3:
import io
StringIO = io.StringIO


# The codes below is not copied from the six package, so the copyright
# declaration at the beginning does not apply.
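
The shim above gives a single `StringIO` name that works on both Python 2 and Python 3. A minimal sketch of the pattern the updated test relies on for capturing a pickle disassembly (assuming nothing beyond the standard library and this new `torch._six.StringIO` alias):

```python
import pickle
import pickletools

from torch._six import StringIO  # io.StringIO on Python 3, StringIO.StringIO on Python 2

# Pickling the same string object twice stores it once and references it via
# the memo (BINPUT/BINGET) -- the behaviour test_serialization_sharing checks for.
s = "a long string"
data = pickle.dumps([s, s], protocol=2)

buf = StringIO()
pickletools.dis(data, out=buf)  # write the human-readable opcode listing into buf
print(buf.getvalue())           # the string appears once, followed by a BINGET reference
```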
6 changes: 3 additions & 3 deletions torch/csrc/jit/passes/python_print.cpp
@@ -827,9 +827,9 @@ struct PythonPrintPass {
if (enforce_importable_) {
throw script::ErrorReport(node->getSourceLocation())
<< "could not export python function call " << value->name()
<< ". Remove calls to Python functions before export."
<< "Did you forget add @script annotation? "
<< "If this is a modulelist, add it to __constants__.";
<< ". Remove calls to Python functions before export. "
<< "Did you forget add @script or @script_method annotation? "
<< "If this is a nn.ModuleList, add it to __constants__.";
}

stmt << "^" << value->name();
(Diffs for the remaining 6 changed files are not shown here.)
