PyTorch Hub is a pre-trained model repository designed to facilitate research reproducibility.
PyTorch Hub supports publishing pre-trained models (model definitions and pre-trained weights)
to a GitHub repository by adding a simple hubconf.py file.
hubconf.py can have multiple entrypoints. Each entrypoint is defined as a Python function with
the following signature:
def entrypoint_name(pretrained=False, *args, **kwargs):
    ...
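To make the signature above concrete, the following minimal sketch shows how Python binds a caller's arguments to it. The function body and the num_classes keyword are stand-ins chosen for illustration; they do not come from any real hubconf.py.

```python
def entrypoint_name(pretrained=False, *args, **kwargs):
    # Stand-in body: report how the caller's arguments were bound.
    return pretrained, args, kwargs


# Keyword use: pretrained is set by name, everything else lands in kwargs.
by_keyword = entrypoint_name(pretrained=True, num_classes=10)

# Positional use: the first positional value binds to pretrained,
# and any further positional values are collected in args.
by_position = entrypoint_name(True, 3, 4)

# Combining a leading positional value with pretrained=... is rejected
# by Python, because pretrained would receive two values.
try:
    entrypoint_name(1234, pretrained=True)
    conflict = None
except TypeError as exc:
    conflict = str(exc)
```

Because pretrained is the first parameter, passing it by keyword and passing model options as keyword arguments is the least surprising way to call an entrypoint with this signature.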
Here is a code snippet from the pytorch/vision repository, which specifies an entrypoint
for the resnet18 model. You can see the full script in the pytorch/vision repo.
import torch.utils.model_zoo as model_zoo  # provides load_url, used below

dependencies = ['torch', 'math']

def resnet18(pretrained=False, *args, **kwargs):
    """
    Resnet18 model
    pretrained (bool): a recommended kwargs for all entrypoints
    args & kwargs are arguments for the function
    """
    ######## Call the model in the repo ###############
    from torchvision.models.resnet import resnet18 as _resnet18
    model = _resnet18(*args, **kwargs)
    ######## End of call ##############################
    # The following logic is REQUIRED
    if pretrained:
        # For weights saved in local repo
        # model.load_state_dict(<path_to_saved_file>)

        # For weights saved elsewhere
        checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
        model.load_state_dict(model_zoo.load_url(checkpoint, progress=False))
    return model
- The dependencies variable is a list of package names required to run the model.
- Pretrained weights can either be stored locally in the GitHub repo or be loadable by model_zoo.load_url().
- pretrained controls whether to load the pre-trained weights provided by repo owners.
- args and kwargs are passed along to the real callable function.
- The docstring of the function works as a help message, explaining what the model does and what the allowed arguments are.
- An entrypoint function should ALWAYS return a model (nn.Module).
- The published models should be at least in a branch/tag. They cannot be published from an arbitrary commit.
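The notes above can be illustrated with a small loader written only against the standard library. This is a sketch of the general idea, not the torch.hub implementation: the function load_entrypoint, the tiny_model entrypoint, and the dict it returns in place of an nn.Module are all hypothetical.

```python
import importlib.util
import os
import tempfile


def load_entrypoint(hubconf_path, name):
    """Import a hubconf.py by path and return the named entrypoint function."""
    spec = importlib.util.spec_from_file_location("hubconf", hubconf_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # Every package listed in `dependencies` must be importable.
    missing = [dep for dep in getattr(module, "dependencies", [])
               if importlib.util.find_spec(dep) is None]
    if missing:
        raise RuntimeError("Missing dependencies: " + ", ".join(missing))

    func = getattr(module, name, None)
    if not callable(func):
        raise RuntimeError("Cannot find callable " + name + " in hubconf")
    return func


# Hypothetical hubconf.py contents; a real entrypoint would return an nn.Module.
HUBCONF_SOURCE = '''
dependencies = ['math']

def tiny_model(pretrained=False, *args, **kwargs):
    """Return a dict standing in for a model."""
    return {"pretrained": pretrained, "args": args, "kwargs": kwargs}
'''


def demo():
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "hubconf.py")
        with open(path, "w") as f:
            f.write(HUBCONF_SOURCE)
        entry = load_entrypoint(path, "tiny_model")
        # The docstring doubles as the help message for the entrypoint.
        return entry.__doc__, entry(pretrained=True, width=8)
```

Calling demo() returns the entrypoint's docstring together with the stand-in model, showing that pretrained and the remaining keyword arguments reach the function unchanged.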
Users can load pre-trained models using the torch.hub.load() API.
.. automodule:: torch.hub
.. autofunction:: load
Here's an example loading the resnet18 entrypoint from the pytorch/vision repo.
import torch.hub as hub

hub_model = hub.load(
    'pytorch/vision:master',  # repo_owner/repo_name:branch
    'resnet18',               # entrypoint
    # 1234,                   # positional args for callable would go here [not applicable to resnet]
    pretrained=True)          # kwargs for callable
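The first argument in the example follows the repo_owner/repo_name:branch pattern. As a rough illustration of how such a string can be split into its parts, here is a minimal sketch; parse_repo_spec is a hypothetical helper, and falling back to 'master' when no branch is given is an assumption made for this example, not documented behavior.

```python
def parse_repo_spec(spec, default_branch="master"):
    """Split 'repo_owner/repo_name:branch' into (owner, name, branch)."""
    repo_info, sep, branch = spec.partition(":")
    if not sep or not branch:
        # Assumption for this sketch: use a default when no branch/tag is given.
        branch = default_branch
    owner, slash, name = repo_info.partition("/")
    if not slash or not owner or not name:
        raise ValueError("Expected 'repo_owner/repo_name[:branch]', got: " + spec)
    return owner, name, branch


owner, name, branch = parse_repo_spec("pytorch/vision:master")
```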
The location for downloaded models and weights is resolved in the following order:
- hub_dir: a user-specified path. It can be set in the following ways:
  - Setting the environment variable TORCH_HUB_DIR
  - Calling hub.set_dir(<PATH_TO_HUB_DIR>)
- ~/.torch/hub
.. autofunction:: set_dir
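The lookup order described above can be sketched as a small pure function. This is an illustration only: resolve_hub_dir is a hypothetical helper, and giving an explicitly set directory priority over the TORCH_HUB_DIR environment variable is an assumption, since the text does not say which of the two wins.

```python
import os

# Default location named in the text.
DEFAULT_HUB_DIR = os.path.join("~", ".torch", "hub")


def resolve_hub_dir(explicit_dir=None, environ=None):
    """Pick the hub directory: explicit path, then TORCH_HUB_DIR, then default."""
    environ = os.environ if environ is None else environ
    if explicit_dir:
        # Assumption: a directory set in code takes priority over the environment.
        chosen = explicit_dir
    elif environ.get("TORCH_HUB_DIR"):
        chosen = environ["TORCH_HUB_DIR"]
    else:
        chosen = DEFAULT_HUB_DIR
    return os.path.expanduser(chosen)
```

Passing the environment as a parameter keeps the function easy to test without modifying the real process environment.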
By default, we don't clean up files after loading them. Hub uses the cache by default if it already exists in hub_dir.
Users can force a reload by calling hub.load(..., force_reload=True). This deletes
the existing GitHub folder and downloaded weights and starts a fresh download. This is useful
when updates are published to the same branch, so that users can keep up with the latest release.
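The caching behavior described above can be sketched with the standard library alone. This is a simplified illustration, not the torch.hub implementation: get_repo_dir, the fake_fetch callback, and the folder name pytorch_vision_master are hypothetical, and the sketch tracks only the repository folder, not downloaded weights.

```python
import os
import shutil
import tempfile


def get_repo_dir(hub_dir, repo_dirname, fetch, force_reload=False):
    """Return the cached repo folder, downloading only when needed."""
    repo_dir = os.path.join(hub_dir, repo_dirname)
    if force_reload and os.path.isdir(repo_dir):
        # Forced reload: discard the cached copy first.
        shutil.rmtree(repo_dir)
    if not os.path.isdir(repo_dir):
        os.makedirs(repo_dir)
        fetch(repo_dir)  # stands in for the real download step
    return repo_dir


def demo():
    downloads = []

    def fake_fetch(target_dir):
        # Record the call and write a placeholder file instead of downloading.
        downloads.append(target_dir)
        with open(os.path.join(target_dir, "hubconf.py"), "w") as f:
            f.write("dependencies = []\n")

    with tempfile.TemporaryDirectory() as hub_dir:
        get_repo_dir(hub_dir, "pytorch_vision_master", fake_fetch)
        get_repo_dir(hub_dir, "pytorch_vision_master", fake_fetch)  # cache hit
        after_cached = len(downloads)
        get_repo_dir(hub_dir, "pytorch_vision_master", fake_fetch, force_reload=True)
        after_forced = len(downloads)
    return after_cached, after_forced
```

In demo(), the second call reuses the cached folder, while the call with force_reload=True triggers one more fetch.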