Skip to content

Commit

Permalink
[package] Allow save_module to accept module as arg (pytorch#55996)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#55996

**Sumamary**
This commit modifies `PackageExporter.save_module` so that the `module`
argument can be either a string (`str`) or a module
(`types.ModuleType`).

**Test Plan**
This commit adds a unit test similar to `TestSaveLoad.test_save_module`
that tests that calling `save_module` with a module object works.

**Fixes**
This commit fixes pytorch#55939.

Test Plan: Imported from OSS

Reviewed By: jamesr66a, huiguoo

Differential Revision: D27771781

Pulled By: SplitInfinity

fbshipit-source-id: 57c8cf45575bb8dcfca711759fadfff72efb35e7
  • Loading branch information
Meghan Lele authored and facebook-github-bot committed Apr 14, 2021
1 parent 1a116a9 commit 669a8ac
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 7 deletions.
18 changes: 18 additions & 0 deletions test/package/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,24 @@ def test_save_module(self):
self.assertEqual(package_a_i.result, "package_a")
self.assertIsNot(package_a_i, package_a)

def test_save_module_with_module_object(self):
"""
Test that save_module works with a module object
instead of a module name.
"""
buffer = BytesIO()

with PackageExporter(buffer, verbose=False) as he:
import module_a

he.save_module(module_a)

buffer.seek(0)
hi = PackageImporter(buffer)
module_a_i = hi.import_module("module_a")
self.assertEqual(module_a_i.result, "module_a")
self.assertIsNot(module_a, module_a_i)

def test_save_module_binary(self):
f = BytesIO()
with PackageExporter(f, verbose=False) as he:
Expand Down
21 changes: 14 additions & 7 deletions torch/package/package_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,22 +349,29 @@ def require_module(self, module_name: str, dependencies=True):

self.save_module(module_name, dependencies)

def save_module(self, module_name: str, dependencies=True):
"""Save the code for `module_name` into the package. Code for the module is resolved using the `importers` path to find the
def save_module(self, module: Union[str, types.ModuleType], dependencies=True):
"""Save the code for `module` into the package. Code for the module is resolved using the `importers` path to find the
module object, and then using its `__file__` attribute to find the source code.
Args:
module_name (str): e.g. `my_package.my_subpackage`, code will be saved to provide code for this package.
module (Union[str, types.ModuleType]): e.g. `my_package.my_subpackage`, code will be saved to provide code
for this package.
dependencies (bool, optional): If True, we scan the source for dependencies.
"""
module = self._import_module(module_name)
source = self._get_source_of_module(module)
if isinstance(module, str):
module_name = module
module_obj = self._import_module(module_name)
else:
module_name = module.__name__
module_obj = module

source = self._get_source_of_module(module_obj)
self.save_source_string(
module_name,
source,
hasattr(module, "__path__"),
hasattr(module_obj, "__path__"),
dependencies,
module.__file__,
module_obj.__file__,
)

def save_pickle(
Expand Down

0 comments on commit 669a8ac

Please sign in to comment.