diff --git a/test/test_utils.py b/test/test_utils.py index cb0eb336c175aa..07c532b81c96c4 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -628,6 +628,13 @@ def test_get_set_dir(self): self.assertEqual(sum_of_state_dict(hub_model.state_dict()), SUM_OF_HUB_EXAMPLE) assert os.path.exists(os.path.join(tmpdir, 'ailzhang_torchhub_example_master')) + # Test that set_dir properly calls expanduser() + # non-regression test for https://github.com/pytorch/pytorch/issues/69761 + new_dir = os.path.join("~", "hub") + torch.hub.set_dir(new_dir) + self.assertEqual(torch.hub.get_dir(), os.path.expanduser(new_dir)) + + @retry(Exception, tries=3) def test_list_entrypoints(self): entry_lists = hub.list('ailzhang/torchhub_example', force_reload=True) diff --git a/torch/hub.py b/torch/hub.py index 4b36117b74a4b5..4a951ff27b4fa9 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -283,7 +283,7 @@ def set_dir(d): d (string): path to a local folder to save downloaded models & weights. """ global _hub_dir - _hub_dir = d + _hub_dir = os.path.expanduser(d) def list(github, force_reload=False, skip_validation=False):