Skip to content

Commit

Permalink
set_dir expanding "~"
Browse files Browse the repository at this point in the history
Fixes pytorch#69761.

Small change to torch.hub.set_dir() (<10 LOC).

It seems that before the code was split into `set_dir()` and `_get_torch_home `, an [earlier version](https://github.com/pytorch/pytorch/blame/5164622ba462fe07fc9f2325fccf7f85aecb3ec8/torch/hub.py#L111) of hub.py had a os.path.expanduser check.

Currently, [_get_torch_home](https://github.com/pytorch/pytorch/blob/master/torch/hub.py#L104) retained the os.path.expanduser check, but `set_dir()` didn't have one. This PR fixes that (I hope).

(As I mentioned in the issue, I can't run the tests on my laptop yet because of storage space :/ But I did include a test.)
Pull Request resolved: pytorch#69763
Approved by: https://github.com/malfet, https://github.com/NicolasHug
  • Loading branch information
loodvn authored and pytorchmergebot committed Mar 23, 2022
1 parent 93a1068 commit 670e4d9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
7 changes: 7 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,13 @@ def test_get_set_dir(self):
self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE)
assert os.path.exists(os.path.join(tmpdir, 'ailzhang_torchhub_example_master'))

# Test that set_dir properly calls expanduser()
# non-regression test for https://github.com/pytorch/pytorch/issues/69761
new_dir = os.path.join("~", "hub")
torch.hub.set_dir(new_dir)
self.assertEqual(torch.hub.get_dir(), os.path.expanduser(new_dir))


@retry(Exception, tries=3)
def test_list_entrypoints(self):
entry_lists = hub.list('ailzhang/torchhub_example', force_reload=True)
Expand Down
2 changes: 1 addition & 1 deletion torch/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def set_dir(d):
d (string): path to a local folder to save downloaded models & weights.
"""
global _hub_dir
_hub_dir = d
_hub_dir = os.path.expanduser(d)


def list(github, force_reload=False, skip_validation=False):
Expand Down

0 comments on commit 670e4d9

Please sign in to comment.